In [0]:
%pip install aiohttp
# Restart the Python process so the package installed above is importable
# by the cells that follow.
dbutils.library.restartPython()

# Databricks Workspace Assessment - Main Execution

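This notebook orchestrates the full workspace assessment: it loads the helper notebooks (`config`, `endpoints`, `data_processing`, `unity_catalog`, `api_client`), collects REST API, DBFS mount, and Unity Catalog data, writes the results to the target catalog and schema, and displays a summary.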

In [0]:
%run "./config"

In [0]:
%run "./endpoints"

In [0]:
%run "./data_processing"

In [0]:
%run "./unity_catalog"

In [0]:
%run "./api_client"

In [0]:
import os, sys
import time
import asyncio
import nest_asyncio
from datetime import datetime, timezone

In [0]:
def main():
    """Run the full Databricks workspace assessment end to end.

    Steps performed:
      1. Read the workspace URL and API token from the notebook context.
      2. Collect REST API data asynchronously and add the DBFS mount points.
      3. Enumerate Unity Catalog schemas and tables.
      4. Write the results: when ENABLE_STREAMING_WRITES is set, raw tables
         are written as they are collected and only the summary is built at
         the end; otherwise everything is written in one batch.
      5. Display the summary and print run statistics.

    The function relies on notebook globals (``spark``, ``dbutils``,
    ``display``) and on names expected to be provided by the ``%run`` cells
    earlier in this notebook (``DataProcessor``, ``DatabricksAPIClient``,
    ``enumerate_uc``, ``build_and_write_summary`` and the configuration
    constants).

    Returns:
        None. A credential lookup failure, or any exception raised during
        collection/writing, is printed (with a traceback for the latter)
        instead of being propagated to the caller.
    """

    # Initialize timing and async support
    start_time = time.time()
    start_ts = datetime.now(timezone.utc).isoformat(timespec="seconds")
    # Allows asyncio.run() below to be used from within the notebook.
    nest_asyncio.apply()

    print("="*80)
    print("[🔐 AUTH] Initializing Databricks authentication...")

    # Get Databricks connection details
    # These variables (spark, dbutils) are available in Databricks notebooks
    try:
        workspace_url = str(spark.conf.get("spark.databricks.workspaceUrl"))
        token = str(dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get())

        print(f"[Init] Connected to workspace: {workspace_url}")
        print("="*80, "\n")

    except Exception as e:
        print(f"[ERROR] Failed to get Databricks credentials: {e}")
        print("Make sure this script is running in a Databricks notebook environment.")
        return

    print("[▶️ RUN] Starting full workspace assessment...")
    print(f"[CONFIG] Streaming writes: {'ENABLED' if ENABLE_STREAMING_WRITES else 'DISABLED'}")

    # Initialize components
    data_processor = DataProcessor(spark, workspace_url, start_ts, TARGET_CATALOG, TARGET_SCHEMA)

    # Initialize API client with optional streaming writes support.
    # NOTE: the target catalog and schema are expected to exist already;
    # this notebook does not create them.
    if ENABLE_STREAMING_WRITES:
        api_client = DatabricksAPIClient(workspace_url, token, data_processor, enable_streaming_writes=True)
        print("🚀 Streaming mode: Raw data will be written to UC immediately after each API call")
    else:
        api_client = DatabricksAPIClient(workspace_url, token)
        print("📦 Batch mode: Raw data will be collected in memory and written at the end")

    try:
        # Step 1: Collect REST API data asynchronously
        raw_data, rest_counts = asyncio.run(api_client.collect_all_endpoints())

        # Add DBFS mounts (requires dbutils)
        mount_data, mount_count = api_client.collect_dbfs_mounts(dbutils)
        rest_counts["dbfs_mount_points"] = mount_count
        raw_data["dbfs_mount_points"] = mount_data

        # Write DBFS mounts immediately if streaming is enabled
        if ENABLE_STREAMING_WRITES and mount_data:
            data_processor.write_single_raw_table("dbfs_mount_points", mount_data)

        # Step 2: Unity Catalog enumeration with batch writing
        base_url = f"https://{workspace_url}"
        headers = {"Authorization": f"Bearer {token}"}

        # Define write callback for batch writing UC tables
        def write_uc_batch(table_type, records):
            """Write one batch of UC records as soon as it is enumerated."""
            # Only handed to enumerate_uc in streaming mode (see below), so no
            # additional ENABLE_STREAMING_WRITES check is needed here.
            data_processor.write_single_raw_table(table_type, records)

        uc_counts, uc_raw_data = enumerate_uc(
            base_url=base_url,
            headers=headers,
            enable=UC_ENABLE,
            allowlist=UC_CATALOG_ALLOWLIST,
            catalog_limit=UC_CATALOG_LIMIT,
            schema_limit_per_catalog=UC_SCHEMA_LIMIT_PER_CATALOG,
            max_workers=UC_MAX_WORKERS,
            write_callback=write_uc_batch if ENABLE_STREAMING_WRITES else None,
            batch_size=UC_TABLE_BATCH_SIZE
        )

        # Normalize the UC results so that a missing/empty result (for example
        # when UC enumeration is disabled) cannot break the steps below.
        uc_counts = uc_counts or {}
        uc_raw_data = uc_raw_data or {}

        # Write UC schema data (tables are already written in batches if streaming)
        uc_schemas = uc_raw_data.get("schemas")
        if uc_schemas:
            raw_data["databricks_schema"] = uc_schemas
            if ENABLE_STREAMING_WRITES:
                data_processor.write_single_raw_table("databricks_schema", uc_schemas)

        # Only write tables in batch mode (streaming mode already wrote them)
        uc_tables = uc_raw_data.get("tables")
        if uc_tables and not ENABLE_STREAMING_WRITES:
            raw_data["databricks_table"] = uc_tables

        # Step 3: Process and write data
        if ENABLE_STREAMING_WRITES:
            # Raw tables already written during collection, just write summary
            print("\n" + "="*22 + " 2/4 RAW Data (Already Streamed) " + "="*22 + "\n")
            print("[STREAM] Raw tables already written during API collection")

            # Combine all counts for summary
            all_counts = rest_counts.copy()
            all_counts.update(uc_counts)

            summary_df = build_and_write_summary(all_counts, TARGET_CATALOG, TARGET_SCHEMA, spark)
        else:
            # Traditional batch mode - write everything at the end
            summary_df = data_processor.process_and_write_all(
                raw_data=raw_data,
                uc_counts=uc_counts,
                catalog=TARGET_CATALOG,
                schema=TARGET_SCHEMA
            )

        # Step 4: Display results
        display(summary_df.orderBy("Category", "Object"))

        # Final summary
        runtime_min = round((time.time() - start_time) / 60, 2)
        total_objects = len(raw_data)
        # A collector may yield None/empty for an object type; count it as 0
        # records instead of failing after all the work has been done.
        total_records = sum(len(records) for records in raw_data.values() if records)

        print(f"\n[✅ DONE] Completed in {runtime_min} min.")
        print(f"[STATS] Collected {total_objects} object types with {total_records:,} total records")
        # Missing UC counters are reported as 0 rather than raising KeyError.
        print(f"[STATS] UC: {uc_counts.get('uc_catalogs', 0)} catalogs, {uc_counts.get('uc_schemas', 0)} schemas, {uc_counts.get('uc_tables', 0)} tables")

    except Exception as e:
        print(f"[ERROR] Assessment failed: {e}")
        import traceback
        traceback.print_exc()

# Run the main assessment. Errors during credential lookup or during
# collection/writing are printed by main() rather than raised.
main()

In [0]:
# catalog_name = "users"
# schema_name = "robert_altmiller"

# # Get all tables in the catalog and schema
# tables_df = spark.sql(f"SHOW TABLES IN {catalog_name}.{schema_name}")
# tables = [row.tableName for row in tables_df.collect()]

# print(f"Found {len(tables)} tables in {catalog_name}.{schema_name}")

# # Drop all tables
# for t in tables:
#     fqtn = f"{catalog_name}.{schema_name}.{t}"
#     print(f"Dropping {fqtn} ...")
#     spark.sql(f"DROP TABLE IF EXISTS {fqtn}")


In [0]:
# # test_jobs_pagination.py
# import asyncio
# import aiohttp
# import os

# async def test_jobs_api_real():
#     """
#     Real integration test for Jobs API with pagination.
#     Tests actual API calls to verify pagination works correctly.
#     """
    
#     # Get credentials from environment or Databricks context
#     workspace_url = os.getenv("DATABRICKS_HOST", "e2-demo-field-eng.cloud.databricks.com")
#     token = os.getenv("DATABRICKS_TOKEN", str(dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()))
    
#     if not token:
#         # Try to get from dbutils if running in Databricks
#         try:
#             token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
#         except:
#             print("❌ No token found. Set DATABRICKS_TOKEN environment variable.")
#             return
    
#     base_url = f"https://{workspace_url}"
#     headers = {"Authorization": f"Bearer {token}"}
    
#     print(f"🔗 Testing Jobs API: {base_url}/api/2.2/jobs/list")
#     print("="*60)
    
#     all_jobs = []
#     page_num = 0
#     params = {"limit": 100}  # Use smaller limit to force pagination
    
#     async with aiohttp.ClientSession() as session:
#         while True:
#             page_num += 1
#             print(f"\n📄 Fetching page {page_num}...")
            
#             async with session.get(
#                 f"{base_url}/api/2.2/jobs/list",
#                 headers=headers,
#                 params=params,
#                 timeout=aiohttp.ClientTimeout(total=60)
#             ) as resp:
#                 if resp.status != 200:
#                     print(f"❌ Error: {resp.status} - {await resp.text()}")
#                     break
                
#                 data = await resp.json()
#                 jobs = data.get("jobs", [])
#                 next_token = data.get("next_page_token")
                
#                 print(f"   ✅ Got {len(jobs)} jobs")
#                 if next_token:
#                     print(f"   🔗 Next token: {next_token[:30]}...")
                
#                 all_jobs.extend(jobs)
                
#                 # Continue pagination
#                 if next_token:
#                     params["page_token"] = next_token
#                 else:
#                     print(f"\n   ℹ️  No more pages")
#                     break
                
#                 # Safety limit
#                 # if page_num >= 10:
#                 #     print(f"\n   ⚠️  Stopping at page {page_num} (safety limit)")
#                 #     break
    
#     print("\n" + "="*60)
#     print(f"✅ Test complete!")
#     print(f"   Total pages fetched: {page_num}")
#     print(f"   Total jobs collected: {len(all_jobs)}")
    
#     if all_jobs:
#         print(f"\n📋 Sample jobs:")
#         for job in all_jobs[:3]:
#             print(f"   - {job.get('job_id')}: {job.get('settings', {}).get('name', 'N/A')}")
    
#     return all_jobs


# jobs = asyncio.run(test_jobs_api_real())