In [0]:
!pip install dotenv

In [0]:
# Cell 1: Import libraries and set up environment
import os
import requests
import json
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML, Markdown
from dotenv import load_dotenv

# Load variables from a local .env file (if one exists) into the process environment
load_dotenv()

# Configuration parameters, each read from the environment with os.getenv.
# Every setting except TOKEN has a fallback value, so a missing variable does
# not fail here; it only surfaces later when a request is actually made.
ACCOUNT_HOST = os.getenv("ACCOUNT_HOST", "https://accounts.azuredatabricks.net")
ACCOUNT_ID = os.getenv("ACCOUNT_ID", "")
CLIENT_ID = os.getenv("CLIENT_ID", "")
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "")
TOKEN = os.getenv("TOKEN")  # PAT token for workspace
DATABRICKS_INSTANCE = os.getenv("DATABRICKS_INSTANCE", "")

# Only confirms this cell ran; it does not validate that the values are set.
print("Environment variables loaded")

In [0]:
# Cell 2: Authentication for both Account and Workspace APIs

# Account-level (Budget Policy) authentication using OAuth
def get_oauth_token():
    """Fetch an OAuth access token for the account-level APIs.

    Performs a client-credentials grant against
    ``{ACCOUNT_HOST}/oidc/accounts/{ACCOUNT_ID}/v1/token`` with
    ``CLIENT_ID``/``CLIENT_SECRET`` sent as HTTP basic auth.

    Returns:
        The ``access_token`` string on success. ``None`` when the request
        cannot be sent, the server answers with a non-200 status, or the
        response body does not contain a token. Failures are printed, never
        raised, so callers only need to check for ``None``.
    """
    token_url = f"{ACCOUNT_HOST}/oidc/accounts/{ACCOUNT_ID}/v1/token"

    try:
        response = requests.post(
            token_url,
            auth=(CLIENT_ID, CLIENT_SECRET),
            # A dict body is form-encoded by requests, which also sets the
            # application/x-www-form-urlencoded Content-Type header.
            data={"grant_type": "client_credentials", "scope": "all-apis"},
            # Without a timeout a stalled connection would hang the notebook.
            timeout=30,
        )
    except requests.RequestException as e:
        print(f"Error getting token: {e}")
        return None

    if response.status_code != 200:
        print(f"Error getting token: {response.status_code}")
        print(response.text)
        return None

    try:
        return response.json()["access_token"]
    except (ValueError, KeyError, TypeError) as e:
        # 200 response that is not JSON or lacks the expected field
        print(f"Error getting token: unexpected response ({e})")
        return None

# Workspace-level (Cluster Policy) API request function
def make_workspace_api_request(method, endpoint, data=None, params=None):
    """Call a Databricks Workspace REST endpoint using PAT authentication.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        endpoint: Path appended to ``https://{DATABRICKS_INSTANCE}``.
        data: Optional JSON-serialisable request body.
        params: Optional query-string parameters.

    Returns:
        The decoded JSON body for 200/201/202/204 responses (``{}`` when the
        body is empty), or ``None`` on any error. Errors are printed rather
        than raised.
    """
    if not DATABRICKS_INSTANCE or not TOKEN:
        # Otherwise the request would target "https://" with "Bearer None".
        print("Error making workspace API request: DATABRICKS_INSTANCE and TOKEN must be set")
        return None

    url = f"https://{DATABRICKS_INSTANCE}{endpoint}"
    headers = {
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/json"
    }

    try:
        response = requests.request(
            method=method,
            url=url,
            headers=headers,
            # `is not None` so an explicitly empty payload ({}) is still sent
            data=json.dumps(data) if data is not None else None,
            params=params,
            # Without a timeout a stalled connection would hang the notebook
            timeout=30,
        )

        if response.status_code in (200, 201, 202, 204):
            return response.json() if response.content else {}

        print(f"Error: {response.status_code}")
        print(response.text)
        return None
    except Exception as e:
        # Deliberately broad: this helper is best-effort and reports every
        # failure (network, JSON decoding, serialisation) as None.
        print(f"Error making workspace API request: {str(e)}")
        return None

# Account-level API request function
def make_account_api_request(method, endpoint, data=None, params=None):
    """Call a Databricks Account REST endpoint using an OAuth bearer token.

    A new token is requested through ``get_oauth_token()`` on every call.

    Args:
        method: HTTP verb, e.g. "GET" or "PATCH".
        endpoint: Path appended to ``ACCOUNT_HOST``.
        data: Optional JSON-serialisable request body.
        params: Optional query-string parameters.

    Returns:
        The decoded JSON body for 200/201/202/204 responses (``{}`` when the
        body is empty), or ``None`` when no token could be obtained or the
        request failed. Errors are printed rather than raised.
    """
    url = f"{ACCOUNT_HOST}{endpoint}"

    try:
        # The token exchange is inside the try so that a failure there is
        # reported like any other request error instead of propagating.
        token = get_oauth_token()
        if not token:
            print("Failed to obtain OAuth token")
            return None

        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }

        response = requests.request(
            method=method,
            url=url,
            headers=headers,
            # `is not None` so an explicitly empty payload ({}) is still sent
            data=json.dumps(data) if data is not None else None,
            params=params,
            # Without a timeout a stalled connection would hang the notebook
            timeout=30,
        )

        if response.status_code in (200, 201, 202, 204):
            return response.json() if response.content else {}

        print(f"Error: {response.status_code}")
        print(response.text)
        return None
    except Exception as e:
        # Deliberately broad: this helper is best-effort and reports every
        # failure (network, JSON decoding, serialisation) as None.
        print(f"Error making account API request: {str(e)}")
        return None

# Test Account and Workspace authentication
def test_authentication():
    # Test account auth
    account_endpoint = f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies"
    account_result = make_account_api_request("GET", account_endpoint)
    account_auth = account_result is not None
    
    # Test workspace auth
    workspace_endpoint = "/api/2.0/policies/clusters/list"
    workspace_result = make_workspace_api_request("GET", workspace_endpoint)
    workspace_auth = workspace_result is not None
    
    print(f"Account API Authentication: {'✅ Success' if account_auth else '❌ Failed'}")
    print(f"Workspace API Authentication: {'✅ Success' if workspace_auth else '❌ Failed'}")
    
    return account_auth and workspace_auth

# Run the authentication check as soon as this cell executes. The outcome is
# kept in auth_success, and a hint is printed when either API check failed.
auth_success = test_authentication()
if not auth_success:
    print("Please check your environment variables and credentials")

### Budget policy API

In [0]:
# Cell 3: Budget Policy API Functions

def list_budget_policies(page_token=None, page_size=100):
    """Fetch one page of the account's budget policies.

    Args:
        page_token: Continuation token for the page to fetch; left out of the
            query when falsy.
        page_size: Value sent as the ``page_size`` query parameter.

    Returns:
        Whatever ``make_account_api_request`` returns for the GET call.
    """
    query = {"page_size": page_size}
    query.update({"page_token": page_token} if page_token else {})

    return make_account_api_request(
        "GET",
        f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies",
        params=query,
    )

def get_budget_policy(policy_id):
    """Retrieve a single budget policy by its ID through the account API.

    Returns whatever ``make_account_api_request`` returns for the GET call.
    """
    return make_account_api_request(
        "GET",
        f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies/{policy_id}",
    )

def create_budget_policy(policy_name, custom_tags=None):
    """Create a budget policy with the given name and tags.

    Args:
        policy_name: Name for the new policy.
        custom_tags: Tag list to attach; an empty list is sent when this is
            falsy.

    Returns:
        Whatever ``make_account_api_request`` returns for the POST call.
    """
    policy_body = {
        "policy_name": policy_name,
        "custom_tags": custom_tags if custom_tags else [],
    }

    # The policy fields are sent nested under a top-level "policy" key
    return make_account_api_request(
        "POST",
        f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies",
        data={"policy": policy_body},
    )

def update_budget_policy(policy_id, policy_name=None, custom_tags=None):
    """Send a PATCH for an existing budget policy.

    Only supplied fields go into the payload: ``policy_name`` when truthy,
    ``custom_tags`` whenever it is not None (so an empty list is sent).

    Returns:
        Whatever ``make_account_api_request`` returns for the PATCH call.
    """
    candidates = (
        ("policy_name", policy_name, bool(policy_name)),
        ("custom_tags", custom_tags, custom_tags is not None),
    )
    payload = {field: value for field, value, include in candidates if include}

    return make_account_api_request(
        "PATCH",
        f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies/{policy_id}",
        data=payload,
    )

def delete_budget_policy(policy_id):
    """Issue a DELETE for the budget policy with the given ID.

    Returns whatever ``make_account_api_request`` returns for the call.
    """
    return make_account_api_request(
        "DELETE",
        f"/api/2.1/accounts/{ACCOUNT_ID}/budget-policies/{policy_id}",
    )

### Cluster policy API

In [0]:
# Cell 4: Cluster Policy API Functions

def list_cluster_policies():
    """Request the workspace's cluster policy list.

    Returns whatever ``make_workspace_api_request`` returns for the GET call.
    """
    return make_workspace_api_request("GET", "/api/2.0/policies/clusters/list")

def get_cluster_policy(policy_id):
    """Fetch one cluster policy, passing its ID as a query parameter.

    Returns whatever ``make_workspace_api_request`` returns for the GET call.
    """
    return make_workspace_api_request(
        "GET",
        "/api/2.0/policies/clusters/get",
        params={"policy_id": policy_id},
    )

def create_cluster_policy(name, definition, description=None, max_clusters_per_user=None):
    """Create a cluster policy.

    Args:
        name: Policy name.
        definition: Policy definition, passed through unchanged.
        description: Included in the payload only when truthy.
        max_clusters_per_user: Included whenever it is not None.

    Returns:
        Whatever ``make_workspace_api_request`` returns for the POST call.
    """
    payload = {"name": name, "definition": definition}

    optional_fields = (
        ("description", description, bool(description)),
        ("max_clusters_per_user", max_clusters_per_user, max_clusters_per_user is not None),
    )
    for field, value, include in optional_fields:
        if include:
            payload[field] = value

    return make_workspace_api_request(
        "POST", "/api/2.0/policies/clusters/create", data=payload
    )

def update_cluster_policy(policy_id, name=None, definition=None,
                          description=None, max_clusters_per_user=None):
    """Post an edit for an existing cluster policy.

    ``policy_id`` is always sent. ``name``, ``definition`` and ``description``
    are added only when truthy; ``max_clusters_per_user`` whenever it is not
    None.

    Returns:
        Whatever ``make_workspace_api_request`` returns for the POST call.
    """
    payload = {"policy_id": policy_id}

    truthy_only = {"name": name, "definition": definition, "description": description}
    payload.update({field: value for field, value in truthy_only.items() if value})

    if max_clusters_per_user is not None:
        payload["max_clusters_per_user"] = max_clusters_per_user

    return make_workspace_api_request(
        "POST", "/api/2.0/policies/clusters/edit", data=payload
    )

def delete_cluster_policy(policy_id):
    """Delete a cluster policy by posting its ID to the delete endpoint.

    Returns whatever ``make_workspace_api_request`` returns for the call.
    """
    return make_workspace_api_request(
        "POST",
        "/api/2.0/policies/clusters/delete",
        data={"policy_id": policy_id},
    )

def set_cluster_policy_permissions(policy_id, access_control_list):
    """PUT an access control list onto a cluster policy.

    Args:
        policy_id: Policy whose permissions are being set.
        access_control_list: Sent unchanged under the ``access_control_list``
            key of the request body.

    Returns:
        Whatever ``make_workspace_api_request`` returns for the PUT call.
    """
    return make_workspace_api_request(
        "PUT",
        f"/api/2.0/permissions/cluster-policies/{policy_id}",
        data={"access_control_list": access_control_list},
    )

### Tag Sync

In [0]:
# Cell 5: Helper Functions for Tag Synchronization

def extract_tags_from_budget_policy(budget_policy):
    """Convert a budget policy's custom tags into cluster-policy rules.

    Each tag with a truthy key and value becomes an entry
    ``"custom_tags.<key>": {"type": "fixed", "value": <value>}``. Tags missing
    either part are skipped.

    Returns:
        Dict of rules; empty when ``budget_policy`` is falsy or has no usable
        tags.
    """
    if not budget_policy:
        return {}

    return {
        f"custom_tags.{entry.get('key')}": {
            "type": "fixed",
            "value": entry.get("value"),
        }
        for entry in budget_policy.get("custom_tags", [])
        if entry.get("key") and entry.get("value")
    }

def merge_policy_definition_with_tags(base_definition, tags):
    """Overlay tag rules onto a cluster policy definition.

    Args:
        base_definition: A dict, a JSON string, or a falsy value (treated as
            an empty definition). A dict is shallow-copied, so the caller's
            top-level mapping is left untouched.
        tags: Mapping of rule key to rule config; entries overwrite any
            existing keys of the same name.

    Returns:
        The merged definition, or ``None`` (after printing an error) when a
        string definition is not valid JSON.
    """
    if not base_definition:
        merged = {}
    elif isinstance(base_definition, str):
        try:
            merged = json.loads(base_definition)
        except json.JSONDecodeError:
            print("Error: Invalid JSON in policy definition")
            return None
    else:
        merged = base_definition.copy()

    # Tag rules win over anything already present under the same key
    for rule_key, rule_config in tags.items():
        merged[rule_key] = rule_config

    return merged

def format_policy_display(policy, policy_type="budget"):
    """Render a policy dict as a multi-line, human-readable string.

    Args:
        policy: Policy dict as returned by the API helpers.
        policy_type: ``"budget"`` formats ID, name and custom tags. Any other
            value formats the dict as a cluster policy: ID, name, creator,
            optional description, and the ``custom_tags*`` entries found in
            its definition.

    Returns:
        The formatted text, or ``"No policy data available"`` when ``policy``
        is falsy. A cluster definition that cannot be parsed, or that is not a
        JSON object, is skipped without raising.
    """
    if not policy:
        return "No policy data available"

    output = []

    if policy_type == "budget":
        output.append(f"Policy ID: {policy.get('policy_id', 'N/A')}")
        output.append(f"Policy Name: {policy.get('policy_name', 'N/A')}")

        tags = policy.get('custom_tags', [])
        if tags:
            output.append("Custom Tags:")
            for tag in tags:
                output.append(f"  - {tag.get('key', 'N/A')}: {tag.get('value', 'N/A')}")
    else:  # cluster policy
        output.append(f"Policy ID: {policy.get('policy_id', 'N/A')}")
        output.append(f"Policy Name: {policy.get('name', 'N/A')}")
        output.append(f"Created by: {policy.get('creator_user_name', 'N/A')}")

        if policy.get('description'):
            output.append(f"Description: {policy.get('description')}")

        definition = policy.get('definition')
        if isinstance(definition, str):
            try:
                definition = json.loads(definition)
            except ValueError:
                # Display is best-effort: an unparseable definition simply
                # contributes no tag lines (JSONDecodeError is a ValueError).
                definition = None

        # Accept definitions that are already dicts as well as parsed JSON
        # objects; anything else (e.g. a JSON list) has no tags to show.
        if isinstance(definition, dict):
            tags = {
                k: v for k, v in definition.items()
                if isinstance(k, str) and k.startswith('custom_tags')
            }
            if tags:
                output.append("Custom Tags in Definition:")
                for tag_key, tag_value in tags.items():
                    key = tag_key.replace('custom_tags.', '')
                    value = tag_value.get('value', 'N/A') if isinstance(tag_value, dict) else str(tag_value)
                    output.append(f"  - {key}: {value}")

    return "\n".join(output)

### UI Elements

In [0]:
# Cell 6: Interactive UI for Budget and Cluster Policy Management

# Template definitions for common roles
TEMPLATE_DEFINITIONS = {
    # Each template maps a cluster attribute to a policy rule. "fixed" rules
    # carry a single "value"; "allowlist" rules carry "values" plus a
    # "defaultValue".
    "Data Science": {
        "spark_version": {"type": "fixed", "value": "11.3.x-cpu-ml-scala2.12"},
        "node_type_id": {
            "type": "allowlist",
            "values": ["Standard_DS3_v2", "Standard_DS4_v2", "Standard_NC4as_T4_v3"],
            "defaultValue": "Standard_DS3_v2",
        },
        "autotermination_minutes": {"type": "fixed", "value": 120},
    },
    "Data Engineering": {
        "spark_version": {"type": "fixed", "value": "11.3.x-scala2.12"},
        "node_type_id": {
            "type": "allowlist",
            "values": ["Standard_DS3_v2", "Standard_DS4_v2"],
            "defaultValue": "Standard_DS3_v2",
        },
        "autotermination_minutes": {"type": "fixed", "value": 60},
    },
    "Analytics": {
        "spark_version": {"type": "fixed", "value": "11.3.x-scala2.12"},
        "node_type_id": {"type": "fixed", "value": "Standard_DS3_v2"},
        # Worker count is pinned to zero and flagged hidden
        "num_workers": {"type": "fixed", "value": 0, "hidden": True},
        "autotermination_minutes": {"type": "fixed", "value": 60},
    },
}

# Create widgets
# These are module-level objects: the handlers defined further down read and
# update them by name.
title = widgets.HTML("<h1>Budget and Cluster Policy Integration</h1>")

# Budget policy widgets
budget_section = widgets.HTML("<h2>Budget Policy Selection</h2>")
# Starts with a placeholder entry whose value is the sentinel "loading"
budget_dropdown = widgets.Dropdown(description="Budget Policy:", options=[("Loading...", "loading")])
# Output area where the selected policy's details are printed
budget_details = widgets.Output()

# Create new budget policy section
new_budget_section = widgets.HTML("<h3>Create New Budget Policy</h3>")
new_budget_name = widgets.Text(description="Name:", placeholder="Enter budget policy name")

# Tag widgets for new budget policies
# Each row is an HBox whose children are (key text box, value text box), in
# that order; code that reads the rows relies on this ordering.
tag_rows = []
for i in range(3):  # Support up to 3 tag pairs
    tag_key = widgets.Text(description=f"Tag Key {i+1}:", placeholder="Enter key")
    tag_value = widgets.Text(description=f"Tag Value {i+1}:", placeholder="Enter value")
    tag_rows.append(widgets.HBox([tag_key, tag_value]))

create_budget_button = widgets.Button(description="Create Budget Policy", button_style="primary")
# Output area for progress and result messages from budget policy creation
budget_status = widgets.Output()

# Cluster policy widgets
cluster_section = widgets.HTML("<h2>Cluster Policy Management</h2>")
# The selectable templates are exactly the keys of TEMPLATE_DEFINITIONS
cluster_dropdown = widgets.Dropdown(description="Policy Template:", options=list(TEMPLATE_DEFINITIONS.keys()))
cluster_name = widgets.Text(description="Policy Name:", placeholder="Enter cluster policy name")
# BoundedIntText is the integer input that enforces a range. Plain IntText has
# no `min` trait, so the previous `min=1` was not applied and 0 or negative
# values could be entered. The upper bound is an arbitrary UI cap (the
# widget's own default would be 100).
max_clusters = widgets.BoundedIntText(description="Max clusters:", value=5, min=1, max=1000)
cluster_description = widgets.Textarea(description="Description:", placeholder="Enter description")

create_cluster_button = widgets.Button(description="Create Cluster Policy with Budget Tags", button_style="success")
# Output area for progress and result messages from cluster policy creation
cluster_status = widgets.Output()

# Functions to update the UI
def update_budget_policies_dropdown():
    """Reload ``budget_dropdown`` with every budget policy in the account.

    Follows ``next_page_token`` so that policies beyond the first page are
    listed too. The dropdown always starts with a "Select a policy" entry
    whose value is ``""``; if the first request fails or returns no
    ``policies`` key, a single "No policies found" entry is shown instead.
    """
    options = [("Select a policy", "")]
    got_first_page = False
    page_token = None
    seen_tokens = set()

    while True:
        page = list_budget_policies(page_token=page_token)
        if not page or "policies" not in page:
            # Error or empty response: keep whatever was collected so far
            break

        got_first_page = True
        for policy in page["policies"]:
            policy_id = policy.get("policy_id")
            policy_name = policy.get("policy_name", "Unnamed")
            options.append((policy_name, policy_id))

        page_token = page.get("next_page_token")
        # Stop on the last page, or if the server repeats a token (loop guard)
        if not page_token or page_token in seen_tokens:
            break
        seen_tokens.add(page_token)

    budget_dropdown.options = options if got_first_page else [("No policies found", "")]
    budget_dropdown.value = ""

def on_budget_selection_change(change):
    """Show details for the newly selected budget policy.

    ``change`` is the observer payload; only ``change["new"]`` is read. An
    empty value or the "loading" placeholder counts as no selection. All
    output goes to the ``budget_details`` area, which is cleared first.
    """
    with budget_details:
        budget_details.clear_output()

        selected_id = change["new"]
        if not selected_id or selected_id == "loading":
            print("No policy selected")
            return

        details = get_budget_policy(selected_id)
        if not details:
            print("Failed to load policy details")
            return

        print(format_policy_display(details, "budget"))

def on_create_budget_click(b):
    """Create a budget policy from the form fields when the button is clicked.

    Reads the name and the tag rows, keeps only rows where both key and value
    are non-blank after stripping, and calls ``create_budget_policy``. On
    success the form is cleared and the policy dropdown is reloaded. Messages
    are written to ``budget_status``. The button argument ``b`` is unused.
    """
    with budget_status:
        budget_status.clear_output()

        policy_name = new_budget_name.value.strip()
        if not policy_name:
            print("Error: Policy name is required")
            return

        # Each row's children are (key widget, value widget)
        stripped_pairs = [
            (row.children[0].value.strip(), row.children[1].value.strip())
            for row in tag_rows
        ]
        tag_list = [{"key": k, "value": v} for k, v in stripped_pairs if k and v]

        print(f"Creating budget policy '{policy_name}' with {len(tag_list)} tags...")

        if not create_budget_policy(policy_name, tag_list):
            print("❌ Failed to create budget policy")
            return

        print("✅ Budget policy created successfully!")

        # Reset the form, then reload so the new policy becomes selectable
        new_budget_name.value = ""
        for row in tag_rows:
            for field in row.children[:2]:
                field.value = ""
        update_budget_policies_dropdown()

def on_create_cluster_click(b):
    """Create a cluster policy from a template plus the selected budget tags.

    Validates the name and budget selection, loads the budget policy, merges
    its tags into the chosen template and calls ``create_cluster_policy``.
    Messages are written to ``cluster_status``; the name and description
    fields are cleared on success. The button argument ``b`` is unused.
    """
    with cluster_status:
        cluster_status.clear_output()

        # --- input validation -------------------------------------------
        new_policy_name = cluster_name.value.strip()
        if not new_policy_name:
            print("Error: Cluster policy name is required")
            return

        selected_budget_id = budget_dropdown.value
        if not selected_budget_id:
            print("Error: Please select a budget policy to use its tags")
            return

        source_policy = get_budget_policy(selected_budget_id)
        if not source_policy:
            print("Error: Failed to load selected budget policy")
            return

        # --- build the definition: template first, budget tags on top ----
        template = TEMPLATE_DEFINITIONS.get(cluster_dropdown.value, {})
        tag_rules = extract_tags_from_budget_policy(source_policy)
        full_definition = merge_policy_definition_with_tags(template, tag_rules)
        if not full_definition:
            print("Error: Failed to merge policy definition with tags")
            return

        # --- create the policy ------------------------------------------
        print(f"Creating cluster policy '{new_policy_name}' with {len(tag_rules)} budget tags...")

        created = create_cluster_policy(
            name=new_policy_name,
            definition=json.dumps(full_definition),
            # An empty description is passed as None
            description=cluster_description.value or None,
            max_clusters_per_user=max_clusters.value,
        )

        if not created:
            print("❌ Failed to create cluster policy")
            return

        print(f"✅ Cluster policy created successfully with ID: {created.get('policy_id', 'N/A')}")

        # Reset the free-text inputs for the next policy
        cluster_name.value = ""
        cluster_description.value = ""

# Wire widget events to their callbacks
budget_dropdown.observe(on_budget_selection_change, names="value")
create_budget_button.on_click(on_create_budget_click)
create_cluster_button.on_click(on_create_cluster_click)

# Replace the "Loading..." placeholder with the policies from the API
update_budget_policies_dropdown()

# Render the UI top to bottom: selection, creation form, cluster policy form
display(title, budget_section, budget_dropdown, budget_details)

display(new_budget_section, new_budget_name)
for row in tag_rows:
    display(row)
display(create_budget_button, budget_status)

display(
    cluster_section,
    cluster_name,
    cluster_dropdown,
    max_clusters,
    cluster_description,
    create_cluster_button,
    cluster_status,
)

### View and manage cluster policies

In [0]:
# Cell 7: Manage Existing Cluster Policies

# Widgets for the "manage existing cluster policies" section. They are
# module-level objects that the handlers defined below refer to by name.
manage_title = widgets.HTML("<h2>Manage Existing Cluster Policies</h2>")

# Cluster policy listing and management
refresh_button = widgets.Button(description="Refresh Policies", button_style="info")
# Output area that receives the formatted policy listing
cluster_policies_output = widgets.Output()

# Update cluster policy section
update_section = widgets.HTML("<h3>Update Cluster Policy</h3>")
# Starts with a placeholder entry whose value is the sentinel "loading"
update_policy_dropdown = widgets.Dropdown(description="Select Policy:", options=[("Loading...", "loading")])
update_tags_button = widgets.Button(description="Update with Budget Tags", button_style="warning")
# Output area for messages from the tag update action
update_output = widgets.Output()

def list_formatted_cluster_policies():
    """Build a text listing of all cluster policies.

    Returns:
        A count line followed by one numbered, formatted entry per policy,
        each closed by a 50-dash rule. A fixed message is returned instead
        when the list call fails or has no ``policies`` key.
    """
    response = list_cluster_policies()

    if not (response and "policies" in response):
        return "No cluster policies found or error retrieving policies."

    policy_list = response.get('policies', [])
    lines = [f"Total policies: {len(policy_list)}"]

    for position, item in enumerate(policy_list, start=1):
        formatted = format_policy_display(item, "cluster")
        lines += [f"\n{position}. {formatted}", "-" * 50]

    return "\n".join(lines)

def update_cluster_policies_dropdown():
    """Reload ``update_policy_dropdown`` from the cluster policy list.

    The dropdown starts with a "Select a policy" entry whose value is ``""``,
    followed by one ``(name, policy_id)`` entry per policy. A single
    "No policies found" entry is shown when the list call fails or has no
    ``policies`` key. The selection is reset to ``""`` either way.
    """
    response = list_cluster_policies()

    if response and "policies" in response:
        choices = [("Select a policy", "")] + [
            (item.get("name", "Unnamed"), item.get("policy_id"))
            for item in response.get('policies', [])
        ]
    else:
        choices = [("No policies found", "")]

    update_policy_dropdown.options = choices
    update_policy_dropdown.value = ""

def on_refresh_click(b):
    """Reprint the cluster policy listing and reload the policy dropdown.

    The listing is written to ``cluster_policies_output`` after clearing it.
    The dropdown reload happens after the output context has been left. The
    button argument ``b`` is unused, so the function may be called with None.
    """
    with cluster_policies_output:
        cluster_policies_output.clear_output()
        print("Refreshing cluster policies...")
        print(list_formatted_cluster_policies())

    # Keep the "Select Policy" dropdown in step with the listing
    update_cluster_policies_dropdown()

def on_update_tags_click(b):
    """Apply the selected budget policy's tags to the selected cluster policy.

    Loads both policies, merges the budget tags into the cluster policy's
    current definition and submits the result through
    ``update_cluster_policy``. Messages are written to ``update_output``; on
    success the policy listing is refreshed. The button argument ``b`` is
    unused.
    """
    with update_output:
        update_output.clear_output()

        # Validate selections
        cluster_policy_id = update_policy_dropdown.value
        if not cluster_policy_id or cluster_policy_id == "loading":
            print("Error: Please select a cluster policy to update")
            return

        budget_policy_id = budget_dropdown.value
        if not budget_policy_id:
            print("Error: Please select a budget policy to use its tags")
            return

        # Get the cluster policy
        cluster_policy = get_cluster_policy(cluster_policy_id)
        if not cluster_policy:
            print("Error: Failed to load selected cluster policy")
            return

        # Get the budget policy
        budget_policy = get_budget_policy(budget_policy_id)
        if not budget_policy:
            print("Error: Failed to load selected budget policy")
            return

        # Extract current definition (a JSON string is decoded here)
        current_definition = cluster_policy.get("definition")
        if isinstance(current_definition, str):
            try:
                current_definition = json.loads(current_definition)
            except ValueError:
                # json.JSONDecodeError is a ValueError; anything else should
                # surface rather than be reported as a parse failure.
                print("Error: Failed to parse current policy definition")
                return

        # Extract tags from budget policy
        budget_tags = extract_tags_from_budget_policy(budget_policy)

        # Merge definition with new tags. Only None signals a failed merge;
        # an empty dict is a legitimate (empty) definition.
        new_definition = merge_policy_definition_with_tags(current_definition, budget_tags)
        if new_definition is None:
            print("Error: Failed to merge policy definition with tags")
            return

        print(f"Updating cluster policy '{cluster_policy.get('name')}' with tags from budget policy '{budget_policy.get('policy_name')}'...")

        # Pass the policy's existing name, description and per-user limit
        # along with the new definition so the edit does not drop them.
        # NOTE(review): assumes the edit endpoint replaces the policy rather
        # than merging; sending the existing values is harmless either way.
        result = update_cluster_policy(
            policy_id=cluster_policy_id,
            name=cluster_policy.get('name'),
            definition=json.dumps(new_definition),
            description=cluster_policy.get('description'),
            max_clusters_per_user=cluster_policy.get('max_clusters_per_user'),
        )

        if result is not None:
            print("✅ Cluster policy updated successfully!")

            # Refresh the policies display
            on_refresh_click(None)
        else:
            print("❌ Failed to update cluster policy")

# Hook the buttons up to their callbacks
refresh_button.on_click(on_refresh_click)
update_tags_button.on_click(on_update_tags_click)

# Replace the "Loading..." placeholder with the policies from the API
update_cluster_policies_dropdown()

# Render the listing controls first, then the update controls
display(manage_title, refresh_button, cluster_policies_output)

display(
    update_section,
    update_policy_dropdown,
    budget_dropdown,  # Re-use the same budget policy dropdown
    update_tags_button,
    update_output,
)

# Populate the listing once so it is not empty before the first click
on_refresh_click(None)

### Policy cost analysis dashboard

In [0]:
# Cell 8: Policy Cost Analysis Dashboard

# Widgets for the cost analysis dashboard. They are module-level objects that
# the analysis handler reads by name.
dashboard_title = widgets.HTML("<h2>Policy Cost Analysis Dashboard</h2>")

# Time range selection
# Look-back window in days (1 to 90, default 30)
days_slider = widgets.IntSlider(
    value=30,
    min=1,
    max=90,
    step=1,
    description='Days to analyze:',
    style={'description_width': 'initial'}
)

# Policy filter dropdowns
# The value "all" is the sentinel meaning "do not filter by a budget policy"
analysis_budget_dropdown = widgets.Dropdown(
    description="Budget Filter:",
    options=[("All Budget Policies", "all")],
    value="all"
)

# Analysis options
# The option labels double as the keys the handler compares against, so they
# must stay in sync with the string comparisons made there.
analysis_type = widgets.RadioButtons(
    options=['Overall Usage by Tag', 'Tag-Based Cost Analysis', 'Tag Compliance Report'],
    value='Overall Usage by Tag',
    description='Analysis:',
    style={'description_width': 'initial'}
)

run_analysis_button = widgets.Button(description="Run Analysis", button_style="primary")
# Output area for the query text, result table and chart
analysis_output = widgets.Output()

def get_tag_usage_query(days, tag_key=None, tag_value=None):
    """Build the SQL text for DBU usage grouped by tag set and product.

    Args:
        days: Look-back window in days. Coerced with ``int()`` so that only a
            number can reach the SQL text.
        tag_key: Optional tag key to filter on.
        tag_value: Optional tag value to filter on. The filter is added only
            when both ``tag_key`` and ``tag_value`` are truthy.

    Returns:
        The query string. The ``tags`` column is the JSON text of the tag map.

    Raises:
        ValueError: If ``days`` cannot be converted to an integer.
    """
    days = int(days)

    def quote(text):
        # Escape backslashes and single quotes so a tag such as "o'brien"
        # cannot terminate the SQL string literal early.
        return str(text).replace("\\", "\\\\").replace("'", "\\'")

    # The tag map is serialised with to_json for selecting and grouping:
    # grouping directly on a MAP column is rejected by older Spark versions.
    base_query = f"""
    SELECT 
      to_json(custom_tags) as tags,
      SUM(usage_quantity) as total_dbu,
      billing_origin_product
    FROM 
      system.billing.usage
    WHERE 
      usage_date >= CURRENT_DATE() - INTERVAL {days} DAY
    """

    if tag_key and tag_value:
        base_query += f"\n  AND custom_tags['{quote(tag_key)}'] = '{quote(tag_value)}'"

    base_query += """
    GROUP BY 
      to_json(custom_tags), billing_origin_product
    ORDER BY 
      total_dbu DESC
    """

    return base_query

def get_tag_cost_query(days, tag_key=None):
    """Build the SQL text for estimated list-price cost per value of one tag.

    Usage rows are joined to ``system.billing.list_prices`` on SKU name and
    the price validity window, and grouped by the value of ``tag_key``.

    Args:
        days: Look-back window in days. Coerced with ``int()`` so that only a
            number can reach the SQL text.
        tag_key: Tag key to break costs down by. Required despite the
            ``None`` default, which is kept for signature compatibility.

    Returns:
        The query string.

    Raises:
        ValueError: If ``tag_key`` is empty or ``days`` is not an integer.
    """
    if not tag_key:
        # Previously this produced a query on the literal key 'None'
        raise ValueError("tag_key is required for tag cost analysis")

    days = int(days)
    # Escape backslashes and single quotes so the key cannot break out of the
    # SQL string literal it is embedded in.
    key = str(tag_key).replace("\\", "\\\\").replace("'", "\\'")

    base_query = f"""
    SELECT 
      custom_tags['{key}'] as tag_value,
      SUM(usage.usage_quantity * list_prices.pricing.default) as estimated_cost
    FROM 
      system.billing.usage usage
    JOIN 
      system.billing.list_prices list_prices ON
        usage.sku_name = list_prices.sku_name AND
        usage.usage_start_time >= list_prices.price_start_time AND
        (usage.usage_end_time <= list_prices.price_end_time OR list_prices.price_end_time IS NULL)
    WHERE 
      usage.usage_date >= CURRENT_DATE() - INTERVAL {days} DAY
      AND custom_tags['{key}'] IS NOT NULL
    GROUP BY 
      custom_tags['{key}']
    ORDER BY 
      estimated_cost DESC
    """

    return base_query

def get_tag_compliance_query(days):
    """Build the SQL text that splits usage into tagged and untagged rows.

    A row counts as 'Untagged' when ``custom_tags`` is NULL or an empty map.
    ``resource_count`` is the number of usage rows in each bucket and
    ``total_dbu`` the sum of their usage quantity.

    Args:
        days: Look-back window in days. Coerced with ``int()`` so that only a
            number can reach the SQL text.

    Returns:
        The query string.

    Raises:
        ValueError: If ``days`` cannot be converted to an integer.
    """
    days = int(days)

    # Emptiness is tested with size(): comparing a MAP column with `=` is not
    # supported by Spark SQL, so the former `custom_tags = MAP()` test would
    # fail at analysis time.
    tagged_status = (
        "CASE WHEN custom_tags IS NULL OR size(custom_tags) = 0 "
        "THEN 'Untagged' ELSE 'Tagged' END"
    )

    base_query = f"""
    SELECT 
      {tagged_status} as tagged_status,
      COUNT(*) as resource_count,
      SUM(usage_quantity) as total_dbu
    FROM 
      system.billing.usage
    WHERE 
      usage_date >= CURRENT_DATE() - INTERVAL {days} DAY
    GROUP BY 
      {tagged_status}
    """

    return base_query

def update_analysis_dropdowns():
    """Reload ``analysis_budget_dropdown`` with every budget policy.

    The first entry is always "All Budget Policies" with value ``"all"``,
    followed by one ``(policy_name, policy_id)`` entry per policy.
    ``next_page_token`` is followed so that policies beyond the first page
    are included. If the first request fails or returns no ``policies`` key,
    a single "No policies found" entry (value ``"all"``) is shown instead.
    """
    options = [("All Budget Policies", "all")]
    got_first_page = False
    page_token = None
    seen_tokens = set()

    while True:
        page = list_budget_policies(page_token=page_token)
        if not page or "policies" not in page:
            # Error or empty response: keep whatever was collected so far
            break

        got_first_page = True
        for policy in page["policies"]:
            policy_id = policy.get("policy_id")
            policy_name = policy.get("policy_name", "Unnamed")
            options.append((policy_name, policy_id))

        page_token = page.get("next_page_token")
        # Stop on the last page, or if the server repeats a token (loop guard)
        if not page_token or page_token in seen_tokens:
            break
        seen_tokens.add(page_token)

    if got_first_page:
        analysis_budget_dropdown.options = options
    else:
        analysis_budget_dropdown.options = [("No policies found", "all")]

def on_run_analysis_click(b):
    """Run the selected analysis and render its table and chart.

    Reads the look-back window, analysis type and budget filter from the
    dashboard widgets, builds the matching SQL query, runs it through
    ``spark.sql`` and shows the result as an HTML table plus a matplotlib
    chart. Everything is written to ``analysis_output``, which is cleared
    first. The button argument ``b`` is unused.

    ``spark`` is not defined in this notebook; it is expected to already
    exist in the global namespace. Any exception raised while querying or
    plotting is caught and reported as a printed message.
    """
    with analysis_output:
        analysis_output.clear_output()

        days = days_slider.value
        selected_analysis = analysis_type.value
        budget_policy_id = analysis_budget_dropdown.value

        print(f"Running {selected_analysis} for the last {days} days...")

        # Get tags from selected budget policy
        # Both stay None when "all" is selected or the policy has no tags
        tag_key = None
        tag_value = None

        if budget_policy_id != "all":
            budget_policy = get_budget_policy(budget_policy_id)
            if budget_policy and budget_policy.get("custom_tags"):
                # Use the first tag for analysis
                # (any further tags on the policy are ignored)
                tag = budget_policy.get("custom_tags")[0]
                tag_key = tag.get("key")
                tag_value = tag.get("value")
                print(f"Using tag filter: {tag_key}={tag_value}")

        try:
            # Execute queries using Spark SQL
            if selected_analysis == 'Overall Usage by Tag':
                query = get_tag_usage_query(days, tag_key, tag_value)
                print(f"Executing query:")
                print(query)

                # Run the query through Spark
                df = spark.sql(query)

                # Convert to pandas for visualization
                pdf = df.toPandas()

                if not pdf.empty:
                    # Display results as a table
                    display(HTML(pdf.to_html(index=False)))

                    # Create a bar chart
                    plt.figure(figsize=(10, 6))

                    # If there are many results, limit to top 10
                    # (the query orders by total_dbu descending, so head(10)
                    # keeps the largest rows)
                    if len(pdf) > 10:
                        plot_data = pdf.head(10)
                        title_suffix = " (Top 10)"
                    else:
                        plot_data = pdf
                        title_suffix = ""

                    # Create a simple identifier for each row
                    # Bars are labelled by row position, matching the order of
                    # the table displayed above, not by the tag contents.
                    x_labels = [f"Row {i+1}" for i in range(len(plot_data))]

                    plt.bar(x_labels, plot_data['total_dbu'])
                    plt.title(f'DBU Usage by Tag{title_suffix}')
                    plt.xlabel('Tags')
                    plt.ylabel('Total DBUs')
                    plt.xticks(rotation=45, ha='right')
                    plt.tight_layout()
                    plt.show()
                else:
                    print("No data found for the specified criteria")

            elif selected_analysis == 'Tag-Based Cost Analysis':
                # This analysis groups by one tag key, so it needs a specific
                # budget policy with at least one tag.
                if not tag_key:
                    print("Error: Cannot run tag-based cost analysis without a tag key")
                    print("Please select a specific budget policy")
                    return

                # Only the key is passed; tag_value is not used as a filter here
                query = get_tag_cost_query(days, tag_key)
                print(f"Executing query:")
                print(query)

                # Run the query
                df = spark.sql(query)
                pdf = df.toPandas()

                if not pdf.empty:
                    # Display results as a table
                    display(HTML(pdf.to_html(index=False)))

                    # Create a bar chart
                    plt.figure(figsize=(10, 6))
                    plt.bar(pdf['tag_value'], pdf['estimated_cost'])
                    plt.title(f'Estimated Cost by Tag Value for "{tag_key}"')
                    plt.xlabel('Tag Value')
                    plt.ylabel('Estimated Cost ($)')
                    plt.xticks(rotation=45, ha='right')
                    plt.tight_layout()
                    plt.show()
                else:
                    print("No data found for the specified criteria")

            elif selected_analysis == 'Tag Compliance Report':
                # The compliance query takes only the day window; the budget
                # policy tag filter is not applied to it.
                query = get_tag_compliance_query(days)
                print(f"Executing query:")
                print(query)

                # Run the query
                df = spark.sql(query)
                pdf = df.toPandas()

                if not pdf.empty:
                    # Display results as a table
                    display(HTML(pdf.to_html(index=False)))

                    # Create a pie chart
                    plt.figure(figsize=(8, 8))
                    plt.pie(pdf['total_dbu'], labels=pdf['tagged_status'], autopct='%1.1f%%')
                    plt.title('Tagged vs Untagged Resources (by DBU)')
                    plt.show()
                else:
                    print("No data found for the specified criteria")

        except Exception as e:
            # Catches every failure in the block above (query building, Spark
            # execution, plotting); the hints below assume an access problem.
            print(f"Error running analysis: {str(e)}")
            print("This might be because you don't have access to the system tables or they're not enabled.")
            print("Make sure you have proper permissions to access system.billing.usage")

# Hook the button up to its callback
run_analysis_button.on_click(on_run_analysis_click)

# Fill the budget filter with the policies from the API
update_analysis_dropdowns()

# Render the dashboard: the slider and budget filter share one row
display(
    dashboard_title,
    widgets.HBox([days_slider, analysis_budget_dropdown]),
    analysis_type,
    run_analysis_button,
    analysis_output,
)