# Budget Policy Setup Notebook

This notebook programmatically creates budget policies for each operational category in the Stellar Analytics Division. It implements:

- Resource tagging via Databricks REST API
- Unity Catalog object tagging via SQL
- Budget policy creation with category-specific limits
- Alert threshold configuration
- Cluster policy enforcement

**Author**: SAD Analytics Team  
**Version**: 1.0  
**Last Updated**: April 2025

In [0]:
pip install python-dotenv

In [0]:
# Load environment variables
from dotenv import load_dotenv
import os
    
load_dotenv()

# Connection settings, read from the process environment (which load_dotenv()
# has just populated from a .env file, if one exists). Each is None when unset.
TOKEN = os.getenv("TOKEN")
DATABRICKS_INSTANCE = os.getenv("DATABRICKS_INSTANCE")
CLUSTER_ID = os.getenv("CLUSTER_ID")
WAREHOUSE_ID = os.getenv("WAREHOUSE_ID")
ACCOUNT_ID = os.getenv("ACCOUNT_ID")  # Required for budget policy APIs

# Echo the configuration for a quick sanity check. The bearer token itself is
# never printed: cell output is saved with the notebook and is easily shared,
# so only report whether a token was found.
print(f"TOKEN: {'<set>' if TOKEN else '<not set>'}")
print(f"DATABRICKS_INSTANCE: {DATABRICKS_INSTANCE}")
print(f"CLUSTER_ID: {CLUSTER_ID}")
print(f"WAREHOUSE_ID: {WAREHOUSE_ID}")
print(f"ACCOUNT_ID: {ACCOUNT_ID}")

In [0]:
# Parameters

# Databricks instance configuration
# DATABRICKS_INSTANCE, TOKEN and ACCOUNT_ID are loaded from the environment in
# the previous cell. To override one for this session, assign it here, e.g.
# ACCOUNT_ID = "your-account-id"

# Division-wide configuration
DIVISION = "SAD"  # Stellar Analytics Division
COST_CENTER = "ND-Analytics"

# Operational categories and their monthly budgets (in USD)
OPERATIONAL_CATEGORIES = {
    "DEEP_SPACE_TELEMETRY": 10000,
    "PROPULSION_ANALYTICS": 12000,
    "ORBITAL_MECHANICS": 8000,
    "MATERIALS_SCIENCE": 9000,
    "EXPLORATORY_MISSIONS": 15000,
    "NAVIGATION_SYSTEMS": 7000,
    "EXOPLANET_RESEARCH": 11000,
}

# Email notification list
ADMIN_EMAILS = [
    "shawndeggans@gmail.com",
]

# Warn early about missing connection settings. A missing value is not fatal
# here, but DATABRICKS_INSTANCE and TOKEN are interpolated into every request
# URL / Authorization header and ACCOUNT_ID into the budget-policy URL, so an
# unset value would otherwise only surface later as a failed HTTP call.
for _setting_name, _setting_value, _needed_for in (
    ("DATABRICKS_INSTANCE", DATABRICKS_INSTANCE, "all Databricks REST API calls"),
    ("TOKEN", TOKEN, "authenticating Databricks REST API calls"),
    ("ACCOUNT_ID", ACCOUNT_ID, "budget policy APIs"),
):
    if not _setting_value:
        print(f"WARNING: {_setting_name} is not set. This is required for {_needed_for}.")
        print(
            f"Please set the {_setting_name} environment variable "
            f"or update the {_setting_name} value directly."
        )

In [0]:
# Imports
import requests
import json
import logging
import uuid
from datetime import datetime
from pyspark.sql import SparkSession

# Configure logging
# The helper functions in this notebook report progress through `logger`
# (info on success, error on a non-200 response).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("BudgetPolicySetup")

# Initialize Spark session for SQL operations if needed
# NOTE(review): `spark` is not referenced by any other cell visible in this
# notebook -- confirm whether it is still required.
spark = SparkSession.builder.appName("BudgetPolicySetup").getOrCreate()

# Databricks API base URLs
# All three are built from DATABRICKS_INSTANCE, which must be a bare host name
# (the "https://" scheme is added here). The budget-policy URL also embeds
# ACCOUNT_ID; if either value is None the literal text "None" ends up in the URL.
# NOTE(review): the budget-policy path is account-scoped (/accounts/{id}/...)
# but uses the same host as the workspace APIs -- confirm that this host
# actually serves the account-level endpoint.
API_BASE_URL = f"https://{DATABRICKS_INSTANCE}/api/2.0"
BUDGET_POLICY_BASE_URL = f"https://{DATABRICKS_INSTANCE}/api/2.1/accounts/{ACCOUNT_ID}/budget-policies"
CLUSTER_POLICY_URL = f"{API_BASE_URL}/policies/clusters"

In [0]:
# Execution: Create budget policies for all operational categories
def setup_all_policies():
    """Create a budget policy and a cluster policy for every operational category.

    Iterates over ``OPERATIONAL_CATEGORIES`` and, for each category, calls
    ``create_budget_policy`` and ``create_cluster_policy`` using the
    module-level ``TOKEN``. Those helpers and ``CATEGORY_LIMITS`` are defined
    in later cells of this notebook, so those cells must be run first.

    Returns:
        dict: Maps each category name to a dict with the keys
        ``budget_policy_id``, ``cluster_policy_id`` and ``budget_usd``.
        Either ID is ``None`` when the corresponding create call failed
        (the helper logs the error).
    """
    results = {}

    for category, monthly_budget in OPERATIONAL_CATEGORIES.items():
        logger.info(f"\n{'='*20} Setting up {category} {'='*20}")

        # Create tags for the category
        tags = {
            "OperationalCategory": category,
            "Team": DIVISION,
            "CostCenter": COST_CENTER,
            "Budget": str(monthly_budget),
            "Environment": "PROD",
        }

        # Format tags for budget policy (list of {"key", "value"} dictionaries).
        # This must iterate over `tags`, the dict built just above; the name
        # `tags_dict` does not exist in this function and raised NameError.
        budget_policy_tags = [{"key": k, "value": v} for k, v in tags.items()]

        # Format the budget policy name
        policy_name = f"{category.lower()}_budget_policy"

        # Create budget policy using account API
        budget_policy_id = create_budget_policy(
            name=policy_name,
            custom_tags=budget_policy_tags,
            token=TOKEN,
        )

        # Create cluster policy
        limits = CATEGORY_LIMITS[category]
        cluster_policy_id = create_cluster_policy(
            category,
            limits["instances"],
            limits["max_workers"],
            TOKEN,
        )

        results[category] = {
            "budget_policy_id": budget_policy_id,
            "cluster_policy_id": cluster_policy_id,
            "budget_usd": monthly_budget,
        }

    return results

# Run the setup for every category and keep the per-category outcome
setup_results = setup_all_policies()

# Summarise the outcome: a banner, then one short section per category
print("\n" + "="*50, "Policy Setup Results", "="*50, sep="\n")
for category, result in setup_results.items():
    print(
        f"\n{category}:",
        f"  Budget: ${OPERATIONAL_CATEGORIES[category]:,}",
        f"  Budget Policy ID: {result.get('budget_policy_id', 'N/A')}",
        f"  Cluster Policy ID: {result.get('cluster_policy_id', 'N/A')}",
        sep="\n",
    )


In [0]:
# Validate deployment

def validate_policies():
    """Print a report of which expected budget and cluster policies exist.

    For every key in ``OPERATIONAL_CATEGORIES`` the report looks for a budget
    policy named ``<category>_budget_policy`` and a cluster policy named
    ``<category>_cluster_policy`` (category lower-cased) -- the names used by
    ``setup_all_policies``. Results are printed; nothing is returned.

    Relies on ``list_budget_policies`` and ``create_databricks_headers``,
    which are defined in later cells of this notebook.
    """
    print("\n" + "="*50)
    print("Policy Validation Report")
    print("="*50)

    # Budget policies. They are listed with list_budget_policies() and matched
    # on "policy_name", the field create_budget_policy() sends when creating
    # them. The section is skipped without ACCOUNT_ID because the listing URL
    # cannot be formed correctly.
    if not ACCOUNT_ID:
        print("WARNING: ACCOUNT_ID is not set. Skipping budget policy validation.")
    else:
        budget_policies = list_budget_policies(TOKEN)
        budget_policy_names = {p.get("policy_name") for p in budget_policies}

        print("\nServerless Budget Policies:")
        for category in OPERATIONAL_CATEGORIES:
            policy_name = f"{category.lower()}_budget_policy"
            policy_exists = policy_name in budget_policy_names

            status = "✓" if policy_exists else "✗"
            print(f"{status} {category}: {'Policy deployed' if policy_exists else 'Policy missing'}")

    # Cluster policies. Reported independently of the budget section so that a
    # failure of one listing call does not hide the other result.
    url = f"{CLUSTER_POLICY_URL}/list"
    headers = create_databricks_headers(TOKEN)
    cluster_response = requests.get(url, headers=headers)

    if cluster_response.status_code != 200:
        print(f"Error validating cluster policies: {cluster_response.text}")
        return

    cluster_policies = cluster_response.json().get("policies", [])
    cluster_policy_names = {p.get("name") for p in cluster_policies}

    print("\nCluster Policies:")
    for category in OPERATIONAL_CATEGORIES:
        policy_name = f"{category.lower()}_cluster_policy"
        policy_exists = policy_name in cluster_policy_names

        status = "✓" if policy_exists else "✗"
        print(f"{status} {category}: {'Policy deployed' if policy_exists else 'Policy missing'}")

# Run validation
validate_policies()

In [0]:
# Resource tagging functions (REST API)

def create_databricks_headers(token):
    """Build the HTTP headers sent with each Databricks REST request.

    Args:
        token: Access token; placed in the ``Authorization`` header as a
            bearer credential.

    Returns:
        dict: ``Authorization`` and ``Content-Type`` (JSON) header entries.
    """
    bearer_value = f"Bearer {token}"
    request_headers = {"Authorization": bearer_value}
    request_headers["Content-Type"] = "application/json"
    return request_headers

def tag_cluster(cluster_id, custom_tags, databricks_instance, token):
    """Merge ``custom_tags`` into the tags of an existing cluster.

    Fetches the cluster through ``clusters/get``, overlays ``custom_tags`` on
    the ``custom_tags`` mapping of that response (new values win on a key
    clash) and posts the whole configuration back to ``clusters/edit``.

    Args:
        cluster_id: ID of the cluster to tag.
        custom_tags: Mapping of tag name to tag value.
        databricks_instance: Workspace host name, without scheme.
        token: Authentication token.

    Returns:
        bool: True when the edit call answered HTTP 200; False when either
        request failed (the response text is logged).
    """
    clusters_api = f"https://{databricks_instance}/api/2.0/clusters"
    auth_headers = create_databricks_headers(token)

    # Read the current configuration so that existing settings are sent back
    lookup = requests.get(
        f"{clusters_api}/get",
        headers=auth_headers,
        params={"cluster_id": cluster_id},
    )
    if lookup.status_code != 200:
        logger.error(f"Failed to get cluster config: {lookup.text}")
        return False

    cluster_spec = lookup.json()
    # Overlay the new tags on whatever tags the cluster already carries
    cluster_spec.setdefault("custom_tags", {}).update(custom_tags)

    # Submit the modified configuration
    outcome = requests.post(f"{clusters_api}/edit", headers=auth_headers, json=cluster_spec)
    if outcome.status_code != 200:
        logger.error(f"Failed to tag cluster: {outcome.text}")
        return False

    logger.info(f"Successfully tagged cluster {cluster_id}")
    return True

def tag_sql_warehouse(warehouse_id, custom_tags, databricks_instance, token):
    """Add or update tags on a Databricks SQL warehouse.

    The SQL Warehouses API represents tags as
    ``{"custom_tags": [{"key": ..., "value": ...}, ...]}`` rather than as a
    flat mapping, so the new tags are merged into that list (a new value
    replaces an existing pair with the same key) instead of being written as
    extra top-level keys of the ``tags`` object.

    Args:
        warehouse_id: ID of the warehouse to tag.
        custom_tags: Tags to apply, either as a ``{name: value}`` mapping or
            as a list of ``{"key": ..., "value": ...}`` dictionaries.
        databricks_instance: Workspace host name, without scheme.
        token: Authentication token.

    Returns:
        bool: True when the edit call answered HTTP 200; False when either
        request failed (the response text is logged).
    """
    api_endpoint = f"https://{databricks_instance}/api/2.0/sql/warehouses/{warehouse_id}"
    headers = create_databricks_headers(token)

    # Get current warehouse configuration
    response = requests.get(api_endpoint, headers=headers)
    if response.status_code != 200:
        logger.error(f"Error getting warehouse: {response.text}")
        return False

    warehouse_config = response.json()

    # Index the existing pairs by key so that updates replace, not duplicate
    existing_pairs = (warehouse_config.get("tags") or {}).get("custom_tags") or []
    merged_tags = {pair["key"]: pair.get("value") for pair in existing_pairs}

    if isinstance(custom_tags, dict):
        merged_tags.update(custom_tags)
    else:
        merged_tags.update({pair["key"]: pair.get("value") for pair in custom_tags or []})

    # Prepare the required fields for the edit request.
    # NOTE(review): only id, name and tags are sent -- confirm that the edit
    # endpoint leaves the warehouse's other settings untouched.
    edit_payload = {
        "id": warehouse_id,
        "name": warehouse_config["name"],
        "tags": {
            "custom_tags": [{"key": k, "value": v} for k, v in merged_tags.items()]
        },
    }

    response = requests.post(f"{api_endpoint}/edit", headers=headers, json=edit_payload)

    if response.status_code == 200:
        logger.info(f"Successfully updated tags on SQL warehouse {warehouse_id}")
        return True

    logger.error(f"Error updating tags: {response.text}")
    return False

In [0]:
# Budget policy functions

def create_budget_policy(name, custom_tags, token=None):
    """Create a budget policy and return its ID.

    Args:
        name: Name of the budget policy.
        custom_tags: Tags for the policy, either as a list of
            ``{"key": ..., "value": ...}`` dictionaries or as a plain
            ``{name: value}`` mapping (converted to the list form here).
        token: Authentication token.

    Returns:
        The ``policy_id`` field of the response on HTTP 200 (``None`` if the
        response has no such field); ``None`` on any other status, after
        logging the response text.
    """
    # Normalise a plain mapping into the list-of-pairs form sent to the API
    tag_pairs = (
        [{"key": tag_name, "value": tag_value} for tag_name, tag_value in custom_tags.items()]
        if isinstance(custom_tags, dict)
        else custom_tags
    )

    payload = {
        "policy": {
            "policy_name": name,
            "custom_tags": tag_pairs,
        },
        # A fresh UUID is generated for every call
        "request_id": str(uuid.uuid4()),
    }

    reply = requests.post(
        BUDGET_POLICY_BASE_URL,
        headers=create_databricks_headers(token),
        json=payload,
    )

    if reply.status_code != 200:
        logger.error(f"Failed to create budget policy: {reply.text}")
        return None

    policy_id = reply.json().get("policy_id")
    logger.info(f"Successfully created budget policy '{name}' with ID: {policy_id}")
    return policy_id
        
def list_budget_policies(token=None, page_size=100):
    """List all budget policies, following pagination to the last page.

    Requests pages of ``page_size`` policies and keeps requesting while the
    response carries a ``next_page_token``, so accounts with more policies
    than fit on one page are listed completely.

    Args:
        token: Authentication token.
        page_size: Number of policies requested per page (max 1000).

    Returns:
        list: The policy dictionaries from all pages, or an empty list when
        any request fails (the response text is logged).
    """
    headers = create_databricks_headers(token)
    params = {"page_size": page_size}
    policies = []

    while True:
        response = requests.get(BUDGET_POLICY_BASE_URL, headers=headers, params=params)

        if response.status_code != 200:
            logger.error(f"Failed to retrieve budget policies: {response.text}")
            return []

        page = response.json()
        policies.extend(page.get("policies") or [])

        # An absent or empty token marks the last page
        next_page_token = page.get("next_page_token")
        if not next_page_token:
            break
        params = {"page_size": page_size, "page_token": next_page_token}

    logger.info(f"Successfully retrieved {len(policies)} budget policies")
    return policies
        
def get_budget_policy(policy_id, token=None):
    """Fetch a single budget policy.

    Args:
        policy_id: ID of the policy to retrieve.
        token: Authentication token.

    Returns:
        The decoded JSON body of the response on HTTP 200; ``None`` on any
        other status, after logging the response text.
    """
    reply = requests.get(
        f"{BUDGET_POLICY_BASE_URL}/{policy_id}",
        headers=create_databricks_headers(token),
    )

    if reply.status_code != 200:
        logger.error(f"Failed to retrieve budget policy: {reply.text}")
        return None

    policy = reply.json()
    logger.info(f"Successfully retrieved budget policy {policy_id}")
    return policy
        
def update_budget_policy(policy_id, policy_name=None, custom_tags=None, token=None):
    """Change the name and/or tags of an existing budget policy.

    The current policy is fetched with ``get_budget_policy``, the supplied
    fields are overwritten on it, and the result is sent back with PATCH.

    Args:
        policy_id: ID of the policy to update.
        policy_name: New name; left unchanged when falsy.
        custom_tags: New tags, as a list of ``{"key": ..., "value": ...}``
            dictionaries or a plain mapping; left unchanged when falsy.
        token: Authentication token.

    Returns:
        bool: True when the PATCH answered HTTP 200; False when the current
        policy could not be fetched or the PATCH failed (both are logged).
    """
    # Start from the stored policy so that untouched fields are sent back as-is
    policy = get_budget_policy(policy_id, token)
    if not policy:
        logger.error(f"Cannot update policy {policy_id}: Unable to retrieve current policy")
        return False

    if policy_name:
        policy["policy_name"] = policy_name

    if custom_tags:
        # A plain mapping is converted to the list-of-pairs form
        policy["custom_tags"] = (
            [{"key": tag_name, "value": tag_value} for tag_name, tag_value in custom_tags.items()]
            if isinstance(custom_tags, dict)
            else custom_tags
        )

    reply = requests.patch(
        f"{BUDGET_POLICY_BASE_URL}/{policy_id}",
        headers=create_databricks_headers(token),
        json=policy,
    )

    updated = reply.status_code == 200
    if updated:
        logger.info(f"Successfully updated budget policy {policy_id}")
    else:
        logger.error(f"Failed to update budget policy: {reply.text}")
    return updated
        
def delete_budget_policy(policy_id, token=None):
    """Remove a budget policy.

    Args:
        policy_id: ID of the policy to delete.
        token: Authentication token.

    Returns:
        bool: True when the DELETE answered HTTP 200; False otherwise (the
        response text is logged).
    """
    reply = requests.delete(
        f"{BUDGET_POLICY_BASE_URL}/{policy_id}",
        headers=create_databricks_headers(token),
    )

    deleted = reply.status_code == 200
    if deleted:
        logger.info(f"Successfully deleted budget policy {policy_id}")
    else:
        logger.error(f"Failed to delete budget policy: {reply.text}")
    return deleted

In [0]:
# Cluster policy functions

def create_cluster_policy(category_name, instance_type_limits, max_workers, token):
    """Create the cluster policy for one operational category.

    The policy is named ``<category>_cluster_policy`` (category lower-cased).
    Its definition fixes auto-termination at 30 minutes, fixes the
    OperationalCategory / Team / CostCenter custom tags, restricts
    ``node_type_id`` to an allow-list and caps ``num_workers``.

    Args:
        category_name: Operational category the policy is for.
        instance_type_limits: Allowed node type IDs.
        max_workers: Upper bound for the number of workers.
        token: Authentication token.

    Returns:
        The ``policy_id`` of the new policy on HTTP 200; ``None`` on any other
        status, after logging the response text.
    """
    def fixed(value):
        # A rule that pins an attribute to a single value
        return {"type": "fixed", "value": value}

    rules = {
        "autotermination_minutes": fixed(30),
        "custom_tags.OperationalCategory": fixed(category_name),
        "custom_tags.Team": fixed(DIVISION),
        "custom_tags.CostCenter": fixed(COST_CENTER),
        "node_type_id": {"type": "allowlist", "values": instance_type_limits},
        "num_workers": {"type": "range", "maxValue": max_workers},
    }

    readable_category = category_name.replace('_', ' ').title()
    policy_config = {
        "name": f"{category_name.lower()}_cluster_policy",
        "description": f"Cluster policy for {readable_category}",
        # The definition is sent as a JSON-encoded string inside the JSON body
        "definition": json.dumps(rules),
    }

    reply = requests.post(
        f"{CLUSTER_POLICY_URL}/create",
        headers=create_databricks_headers(token),
        json=policy_config,
    )

    if reply.status_code != 200:
        logger.error(f"Failed to create cluster policy: {reply.text}")
        return None

    policy_id = reply.json()["policy_id"]
    logger.info(f"Successfully created cluster policy for {category_name}: {policy_id}")
    return policy_id
        
def get_cluster_policies(token):
    """Fetch the cluster policies from the ``policies/clusters/list`` endpoint.

    Args:
        token: Authentication token.

    Returns:
        list: The ``policies`` entry of the response on HTTP 200 (an empty
        list if the entry is absent); an empty list on any other status,
        after logging the response text.
    """
    reply = requests.get(
        f"{CLUSTER_POLICY_URL}/list",
        headers=create_databricks_headers(token),
    )

    if reply.status_code != 200:
        logger.error(f"Failed to retrieve cluster policies: {reply.text}")
        return []

    found = reply.json().get("policies", [])
    logger.info(f"Successfully retrieved {len(found)} cluster policies")
    return found

In [0]:
# Define instance type and worker limits per category

# One row per operational category: (category, worker cap, allowed node types).
# Expanded into {category: {"instances": [...], "max_workers": n}}.
CATEGORY_LIMITS = {
    category_name: {"instances": node_types, "max_workers": worker_cap}
    for category_name, worker_cap, node_types in (
        ("DEEP_SPACE_TELEMETRY", 10, ["i3.xlarge", "i3.2xlarge", "m5.xlarge", "m5.2xlarge"]),
        ("PROPULSION_ANALYTICS", 12, ["r5.xlarge", "r5.2xlarge", "r5.4xlarge"]),
        ("ORBITAL_MECHANICS", 8, ["m5.xlarge", "m5.2xlarge", "c5.2xlarge"]),
        ("MATERIALS_SCIENCE", 8, ["c5.xlarge", "c5.2xlarge", "c5.4xlarge"]),
        ("EXPLORATORY_MISSIONS", 16, ["m5.xlarge", "m5.2xlarge", "m5.4xlarge", "r5.2xlarge"]),
        ("NAVIGATION_SYSTEMS", 6, ["m5.xlarge", "c5.xlarge", "c5.2xlarge"]),
        ("EXOPLANET_RESEARCH", 10, ["r5.xlarge", "r5.2xlarge", "i3.2xlarge"]),
    )
}

In [0]:
# Execution: Create budget policies and tag resources

def setup_budget_policies_and_tags():
    """Set up budget policies, cluster policies and resource tags per category.

    For every entry of ``OPERATIONAL_CATEGORIES`` this:

    1. reuses the budget policy named ``<category>_budget_policy`` if one is
       already listed, otherwise creates it;
    2. tags the cluster ``CLUSTER_ID`` and the warehouse ``WAREHOUSE_ID``
       (each only when set) with the category's tags;
    3. reuses the cluster policy named ``<category>_cluster_policy`` if one is
       already listed, otherwise creates it.

    Existing policies are looked up once before the loop, so re-running the
    cell does not try to create policies that are already present.

    Returns:
        dict: Maps each category name to a dict with the keys
        ``budget_policy_id``, ``cluster_policy_id``, ``cluster_tagged``,
        ``warehouse_tagged`` and ``budget_usd``. Empty when ``ACCOUNT_ID`` is
        not set.
    """
    # Check if ACCOUNT_ID is available
    if not ACCOUNT_ID:
        logger.error("ACCOUNT_ID is not set. Cannot create budget policies.")
        return {}

    # Index existing policies by name. The listings do not depend on the
    # category, so they are fetched once instead of once per iteration.
    # setdefault() keeps the first policy when a name occurs more than once.
    existing_budget_policies = {}
    for policy in list_budget_policies(TOKEN):
        existing_budget_policies.setdefault(policy.get("policy_name"), policy.get("policy_id"))

    existing_cluster_policies = {}
    for policy in get_cluster_policies(TOKEN):
        existing_cluster_policies.setdefault(policy.get("name"), policy.get("policy_id"))

    results = {}

    for category, monthly_budget in OPERATIONAL_CATEGORIES.items():
        logger.info(f"\n{'='*20} Setting up {category} {'='*20}")

        # Create tags as dictionary
        tags_dict = {
            "OperationalCategory": category,
            "Team": DIVISION,
            "CostCenter": COST_CENTER,
            "Budget": str(monthly_budget),
            "Environment": "PROD",
        }

        # Format tags for budget policy (list of dictionaries)
        budget_policy_tags = [{"key": k, "value": v} for k, v in tags_dict.items()]

        # Create the budget policy unless one with this name already exists
        budget_policy_name = f"{category.lower()}_budget_policy"
        budget_policy_id = existing_budget_policies.get(budget_policy_name)
        if budget_policy_id:
            logger.info(
                f"Budget policy '{budget_policy_name}' already exists with ID: {budget_policy_id}"
            )
        else:
            budget_policy_id = create_budget_policy(
                name=budget_policy_name,
                custom_tags=budget_policy_tags,
                token=TOKEN,
            )

        # NOTE(review): the same CLUSTER_ID / WAREHOUSE_ID is tagged on every
        # iteration with identical tag keys, so each category overwrites the
        # previous one and the resource ends up with the last category's
        # values. Confirm this is intended.

        # Tag existing cluster if specified
        cluster_tagged = False
        if CLUSTER_ID:
            cluster_tagged = tag_cluster(CLUSTER_ID, tags_dict, DATABRICKS_INSTANCE, TOKEN)

        # Tag existing warehouse if specified
        warehouse_tagged = False
        if WAREHOUSE_ID:
            warehouse_tagged = tag_sql_warehouse(WAREHOUSE_ID, tags_dict, DATABRICKS_INSTANCE, TOKEN)

        # Create the cluster policy unless one with this name already exists
        limits = CATEGORY_LIMITS[category]
        cluster_policy_name = f"{category.lower()}_cluster_policy"

        if cluster_policy_name in existing_cluster_policies:
            cluster_policy_id = existing_cluster_policies[cluster_policy_name]
            logger.info(
                f"Cluster policy '{cluster_policy_name}' already exists with ID: {cluster_policy_id}"
            )
        else:
            cluster_policy_id = create_cluster_policy(
                category,
                limits["instances"],
                limits["max_workers"],
                TOKEN,
            )

        results[category] = {
            "budget_policy_id": budget_policy_id,
            "cluster_policy_id": cluster_policy_id,
            "cluster_tagged": cluster_tagged,
            "warehouse_tagged": warehouse_tagged,
            "budget_usd": monthly_budget,
        }

    return results

# Run the full setup (policies plus resource tagging) and keep the outcome
setup_results = setup_budget_policies_and_tags()

# Summarise the outcome: a banner, then one short section per category
print("\n" + "="*50, "Budget Policy and Tagging Setup Results", "="*50, sep="\n")
for category, result in setup_results.items():
    print(
        f"\n{category}:",
        f"  Budget: ${OPERATIONAL_CATEGORIES[category]:,}",
        f"  Budget Policy ID: {result.get('budget_policy_id', 'N/A')}",
        f"  Cluster Policy ID: {result.get('cluster_policy_id', 'N/A')}",
        f"  Cluster Tagged: {result.get('cluster_tagged', False)}",
        f"  Warehouse Tagged: {result.get('warehouse_tagged', False)}",
        sep="\n",
    )

In [0]:
# Validate deployment

def validate_setup():
    """Print a validation report for budget policies, cluster policies and tags.

    The report has up to three parts:

    * budget policies -- skipped with a warning when ``ACCOUNT_ID`` is unset;
    * cluster policies;
    * the tags currently present on ``CLUSTER_ID`` / ``WAREHOUSE_ID``, for
      whichever of the two is set.

    Everything is printed; nothing is returned.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Validation Report")
    print(rule)

    def report_policies(heading, name_suffix, known_policies, found_text):
        # One line per category: is the expected policy name among the known ones?
        print(heading)
        for category in OPERATIONAL_CATEGORIES:
            if f"{category.lower()}{name_suffix}" in known_policies:
                print(f"✓ {category}: {found_text}")
            else:
                print(f"✗ {category}: Policy missing")

    def show_tags(reply, tag_field, heading, resource_label):
        # Print the tag mapping found under `tag_field`, or the error text
        if reply.status_code != 200:
            print(f"Error checking {resource_label} tags: {reply.text}")
            return
        resource_tags = reply.json().get(tag_field, {})
        print(heading)
        for key, value in resource_tags.items():
            print(f"  {key}: {value}")

    # Budget policies need ACCOUNT_ID for the listing URL
    if ACCOUNT_ID:
        listed_budget_policies = list_budget_policies(TOKEN)
        budget_ids_by_name = {
            p["policy_name"]: p["policy_id"] for p in (listed_budget_policies or [])
        }
        report_policies("\nBudget Policies:", "_budget_policy", budget_ids_by_name, "Policy created")
    else:
        print("WARNING: ACCOUNT_ID is not set. Skipping budget policy validation.")

    # Cluster policies
    listed_cluster_policies = get_cluster_policies(TOKEN)
    cluster_ids_by_name = {
        p["name"]: p["policy_id"] for p in (listed_cluster_policies or [])
    }
    report_policies("\nCluster Policies:", "_cluster_policy", cluster_ids_by_name, "Policy deployed")

    # Resource tags are only reported for resources that were configured
    if not (CLUSTER_ID or WAREHOUSE_ID):
        return

    print("\nResource Tagging:")
    auth_headers = create_databricks_headers(TOKEN)

    if CLUSTER_ID:
        cluster_reply = requests.get(
            f"{API_BASE_URL}/clusters/get",
            headers=auth_headers,
            params={"cluster_id": CLUSTER_ID},
        )
        show_tags(cluster_reply, "custom_tags", f"\nCluster Tags for {CLUSTER_ID}:", "cluster")

    if WAREHOUSE_ID:
        warehouse_reply = requests.get(
            f"{API_BASE_URL}/sql/warehouses/{WAREHOUSE_ID}",
            headers=auth_headers,
        )
        show_tags(warehouse_reply, "tags", f"\nSQL Warehouse Tags for {WAREHOUSE_ID}:", "warehouse")

# Run validation
validate_setup()

## Next Steps

Once the validation report above shows every budget policy and cluster policy as present, the budget policies have been successfully created for all operational categories. Next steps:

1. Deploy the category-specific budget monitoring notebooks
2. Configure the master budget dashboard
3. Schedule regular budget reports
4. Train team members on budget compliance

Remember to regularly review and adjust budget allocations based on actual usage patterns.