# Unity Catalog Enumeration Logic

This notebook provides utilities for enumerating Unity Catalog objects (catalogs, schemas, tables).

In [0]:
import gc
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

import requests

# Catalog processing configuration
# CATALOG_BATCH_SIZE: number of catalogs whose schema/table listings are queued
#   together as one API batch. Used as the range() step, so it must be > 0.
# CATALOG_WRITE_BATCH: when a write_callback is supplied, accumulated table
#   records are flushed once at least this many catalogs *that returned tables*
#   have accumulated. The check only runs after each API batch (and always
#   flushes on the last batch), so one flush may cover more catalogs than this.
CATALOG_BATCH_SIZE = 2           # Catalogs per API batch (parallel API calls) - DON'T MODIFY
CATALOG_WRITE_BATCH = 3          # Min. catalogs-with-tables per flush (Delta writes) - DON'T MODIFY

In [0]:
def log(msg: str):
    """Print ``msg`` only when the notebook-level ``VERBOSE_LOG`` flag is truthy."""
    if not VERBOSE_LOG:
        return
    print(msg)

def banner(txt: str):
    """Print ``txt`` framed by 22 '=' characters on each side, with blank lines around it."""
    rule = "=" * 22
    print(f"\n{rule} {txt} {rule}\n")

def keep_spark_alive(spark, stop_event, interval=30):
    """
    Keep a Spark session alive by running a trivial query periodically.

    Run this in a background thread during long operations to prevent session
    timeout. The loop exits as soon as ``stop_event`` is set, including while it
    is waiting between pings. A caller that sets the event and then ``join()``s
    the thread is therefore not blocked for up to ``interval`` seconds.

    Args:
        spark: SparkSession instance (only ``spark.sql(query).count()`` is used)
        stop_event: threading.Event; set it to stop the loop
        interval: Seconds between keep-alive pings (default: 30)
    """
    ping_count = 0
    while not stop_event.is_set():
        try:
            # Trivial query; count() makes Spark actually run it.
            spark.sql("SELECT 1").count()
            ping_count += 1
            print(f"[KEEP-ALIVE] ✓ Ping #{ping_count} successful (interval: {interval}s)")
        except Exception as e:
            # Best-effort: report (verbose only) and retry on the next interval.
            log(f"[KEEP-ALIVE] ✗ Warning: {e}")
        # Unlike time.sleep, Event.wait returns immediately once the event is set.
        stop_event.wait(interval)

In [0]:
class UnityCatalogEnumerator:
    """Enumerate Unity Catalog catalogs, schemas and tables via the REST API.

    Schemas are listed per catalog on the calling thread. Per-schema table
    listings are fanned out to a thread pool. Catalogs are processed in
    batches of ``CATALOG_BATCH_SIZE``. When a write callback is supplied,
    table records are flushed roughly every ``CATALOG_WRITE_BATCH`` catalogs
    instead of being held in memory.
    """

    def __init__(self, base_url: str, headers: Dict[str, str]):
        # base_url: prefix for the /api/2.1/unity-catalog/... endpoints.
        # headers: sent with every request (presumably carries auth - confirm with caller).
        self.base_url = base_url
        self.headers = headers

    def list_tables_for_schema(self, catalog: str, schema: str) -> Tuple[int, int, int]:
        """Return ``(tables, managed, external)`` counts for ``catalog.schema``.

        Delegates to :meth:`list_tables_for_schema_with_data`, so the counts
        come from the same paginated listing as the detailed records.
        Failures are logged there. A schema that cannot be read yields
        ``(0, 0, 0)``.
        """
        tables = self.list_tables_for_schema_with_data(catalog, schema)
        managed = sum(1 for t in tables if t.get("table_type") == "MANAGED")
        external = sum(1 for t in tables if t.get("table_type") == "EXTERNAL")
        return len(tables), managed, external

    def list_tables_for_schema_with_data(self, catalog: str, schema: str) -> List[Dict[str, Any]]:
        """Return table metadata for ``catalog.schema``, following pagination.

        Each returned dict is the API's table record plus two context keys,
        ``_catalog`` and ``_schema``.

        A non-200 response or an exception stops paging and is logged. The
        tables gathered so far are returned, which is an empty list if the
        first page already failed.
        """
        all_tables: List[Dict[str, Any]] = []
        next_page_token = None
        page_count = 0
        try:
            while True:
                params = {"catalog_name": catalog, "schema_name": schema, "max_results": 500}
                if next_page_token:
                    params["page_token"] = next_page_token

                r = requests.get(
                    f"{self.base_url}/api/2.1/unity-catalog/tables",
                    headers=self.headers,
                    params=params,
                    timeout=HTTP_TIMEOUT_SEC,
                )
                if r.status_code != 200:
                    log(f"[UC-SKIP] {catalog}.{schema}: HTTP {r.status_code}")
                    break

                data = r.json()
                tables = data.get("tables", [])
                page_count += 1

                # Add catalog and schema context to each table
                for tbl in tables:
                    tbl["_catalog"] = catalog
                    tbl["_schema"] = schema
                all_tables.extend(tables)

                next_page_token = data.get("next_page_token")
                if not next_page_token:
                    break  # No more pages

                log(f"[UC-PAGINATION] {catalog}.{schema}: page {page_count} fetched {len(tables)} tables, continuing...")
        except Exception as e:
            # Keep pages already fetched, matching the non-200 path above.
            log(f"[UC-ERROR] {catalog}.{schema}: {e}")
        else:
            if page_count > 1:
                log(f"[UC-PAGINATION] {catalog}.{schema}: completed {page_count} pages, total {len(all_tables)} tables")
        return all_tables

    def get_catalogs(self, allowlist: Optional[List[str]] = None, catalog_limit: int = 0) -> List[str]:
        """Return the sorted list of catalog names to scan.

        Args:
            allowlist: If non-empty, keep only these catalogs. Duplicates are
                ignored. ``hive_metastore`` is kept whenever it is allowlisted,
                because the catalogs API does not return it.
            catalog_limit: If > 0, keep only the first N names after sorting.

        Returns an empty list (and logs) if the catalogs request fails.
        """
        try:
            rr = requests.get(f"{self.base_url}/api/2.1/unity-catalog/catalogs",
                              headers=self.headers, timeout=HTTP_TIMEOUT_SEC)
            rr.raise_for_status()
            catalogs = [c["name"] for c in rr.json().get("catalogs", [])]

            if allowlist:
                # hive_metastore is not returned by the catalogs API but exists
                # in the legacy metastore, so add it when explicitly requested.
                if "hive_metastore" in allowlist and "hive_metastore" not in catalogs:
                    catalogs.append("hive_metastore")
                    log("[UC] Added legacy hive_metastore to catalog list")

                available = set(catalogs)
                # dict.fromkeys de-duplicates (order-preserving), so a repeated
                # allowlist entry is not scanned and counted twice.
                catalogs = [c for c in dict.fromkeys(allowlist) if c in available]

            catalogs = sorted(catalogs)
            if catalog_limit and catalog_limit > 0:
                catalogs = catalogs[:catalog_limit]
            return catalogs
        except Exception as e:
            log(f"[UC-ERROR] Failed to get catalogs: {e}")
            return []

    def get_schemas_for_catalog(self, catalog: str, schema_limit: int = 0) -> List[str]:
        """Return sorted schema names for ``catalog``, following pagination.

        Args:
            catalog: Catalog to list.
            schema_limit: If > 0, keep only the first N names after sorting.

        Any non-200 response, on any page, skips the whole catalog and returns
        ``[]``. Exceptions are logged and also return ``[]``.
        """
        try:
            all_schemas: List[str] = []
            next_page_token = None
            page_count = 0

            while True:
                params = {"catalog_name": catalog, "max_results": 500}
                if next_page_token:
                    params["page_token"] = next_page_token

                rs = requests.get(
                    f"{self.base_url}/api/2.1/unity-catalog/schemas",
                    headers=self.headers,
                    params=params,
                    timeout=HTTP_TIMEOUT_SEC,
                )
                if rs.status_code != 200:
                    log(f"[UC] Skip catalog {catalog} (status={rs.status_code})")
                    return []

                data = rs.json()
                schemas = [s["name"] for s in data.get("schemas", [])]
                page_count += 1
                all_schemas.extend(schemas)

                next_page_token = data.get("next_page_token")
                if not next_page_token:
                    break  # No more pages

                log(f"[UC-PAGINATION] {catalog}: page {page_count} fetched {len(schemas)} schemas, continuing...")

            if page_count > 1:
                log(f"[UC-PAGINATION] {catalog}: completed {page_count} pages, total {len(all_schemas)} schemas")

            all_schemas = sorted(all_schemas)
            if schema_limit and schema_limit > 0:
                all_schemas = all_schemas[:schema_limit]
            return all_schemas
        except Exception as e:
            log(f"[UC-ERROR] Failed to get schemas for {catalog}: {e}")
            return []

    def _scan_catalog_batch(
        self,
        ex: ThreadPoolExecutor,
        catalog_batch: List[str],
        schema_limit: int,
        out: Dict[str, int],
        raw_data: Dict[str, List[Dict[str, Any]]],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """List schemas for every catalog in the batch and fetch their tables via ``ex``.

        Side effects: increments the schema/table counters in ``out`` and
        appends one record per schema to ``raw_data["schemas"]``.
        Returns ``{catalog_name: [table records]}`` for the batch.
        """
        futures = {}  # {future: catalog_name}
        tables_by_catalog: Dict[str, List[Dict[str, Any]]] = {}

        # Queue table listings for ALL schemas of ALL catalogs in this batch.
        for cat in catalog_batch:
            schemas = self.get_schemas_for_catalog(cat, schema_limit)
            out["uc_schemas"] += len(schemas)
            print(f"[UC]   {cat}: {len(schemas)} schemas")
            tables_by_catalog[cat] = []

            for s in schemas:
                raw_data["schemas"].append({
                    "catalog_name": cat,
                    "schema_name": s,
                    "full_name": f"{cat}.{s}",
                })
                futures[ex.submit(self.list_tables_for_schema_with_data, cat, s)] = cat
                time.sleep(0.001)  # Small throttle between submissions

        # Gather results as they finish; the listing method never raises.
        for future in as_completed(futures):
            cat = futures[future]
            table_data = future.result()
            if not table_data:
                continue
            tables_by_catalog[cat].extend(table_data)
            out["uc_tables"] += len(table_data)
            out["managed_tables"] += sum(1 for t in table_data if t.get("table_type") == "MANAGED")
            out["external_tables"] += sum(1 for t in table_data if t.get("table_type") == "EXTERNAL")

        return tables_by_catalog

    @staticmethod
    def _describe_scope(allowlist: Optional[List[str]], catalog_limit: int, schema_limit: int) -> str:
        """Return a short description of the active scan limits, or ``"full-scan"``."""
        scope = []
        if allowlist:
            scope.append(f"allow={allowlist}")
        if catalog_limit:
            scope.append(f"cat_limit={catalog_limit}")
        if schema_limit:
            scope.append(f"schema_limit={schema_limit}")
        return ", ".join(scope) if scope else "full-scan"

    def enumerate_unity_catalog(
        self,
        enable: bool = True,
        allowlist: Optional[List[str]] = None,
        catalog_limit: int = 0,
        schema_limit_per_catalog: int = 0,
        max_workers: int = 20,
        spark = None,
        write_callback = None
    ) -> Tuple[Dict[str, int], Dict[str, List[Dict[str, Any]]]]:
        """
        Enumerate Unity Catalog with configurable limits and threading.

        Args:
            enable: If False, print a skip notice and return zero counts.
            allowlist: Restrict the scan to these catalogs (empty/None = all).
            catalog_limit: Max catalogs to scan after sorting (0 = no limit).
            schema_limit_per_catalog: Max schemas per catalog (0 = no limit).
            max_workers: Thread-pool size for the per-schema table listings.
            spark: Unused. Spark keep-alive is managed globally in main.ipynb;
                the parameter is kept for backward compatibility.
            write_callback: Optional function(table_type, records). When given,
                table records are passed to it with table_type
                ``"databricks_table"`` in chunks of roughly
                CATALOG_WRITE_BATCH catalogs. In that case they are not kept in
                ``raw_data["tables"]``.

        Returns tuple of (counts, raw_data):
        - counts: dict with uc_catalogs, uc_schemas, uc_tables, managed_tables,
          external_tables. If enumeration fails part-way, the counts reflect
          the work completed before the failure.
        - raw_data: dict with "schemas" and "tables" lists containing full
          metadata. "tables" stays empty when write_callback is used.
        """
        banner("3/4 Unity Catalog Enumeration")
        out = {
            "uc_catalogs": 0,
            "uc_schemas": 0,
            "uc_tables": 0,
            "managed_tables": 0,
            "external_tables": 0
        }
        raw_data = {
            "schemas": [],
            "tables": []
        }

        if not enable:
            print("[UC] Skipped (UC_ENABLE=False).")
            return out, raw_data

        try:
            allowlist = allowlist or []

            catalogs = self.get_catalogs(allowlist, catalog_limit)
            if not catalogs:
                print("[UC] No catalogs found or accessible.")
                return out, raw_data

            out["uc_catalogs"] = len(catalogs)

            print(f"[UC] Scope: allowlist={allowlist or 'ALL'}, cat_limit={catalog_limit or 'ALL'}, schema_limit={schema_limit_per_catalog or 'ALL'}")
            print(f"[UC] Catalogs to scan: {len(catalogs)} → {catalogs[:5]}{' …' if len(catalogs)>5 else ''}")

            n_catalogs = len(catalogs)
            total_batches = (n_catalogs + CATALOG_BATCH_SIZE - 1) // CATALOG_BATCH_SIZE
            accumulated_tables: List[Dict[str, Any]] = []  # Pending records across catalogs
            catalogs_since_last_write = 0

            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                # Process catalogs in BATCHES for speed + memory balance
                for batch_start in range(0, n_catalogs, CATALOG_BATCH_SIZE):
                    batch_end = min(batch_start + CATALOG_BATCH_SIZE, n_catalogs)
                    catalog_batch = catalogs[batch_start:batch_end]
                    batch_num = batch_start // CATALOG_BATCH_SIZE + 1
                    is_last_batch = batch_end == n_catalogs

                    print(f"\n[UC] Processing API batch {batch_num}/{total_batches}: catalogs {batch_start+1}-{batch_end}")

                    tables_by_catalog = self._scan_catalog_batch(
                        ex, catalog_batch, schema_limit_per_catalog, out, raw_data
                    )

                    # Route tables in catalog order; only catalogs with tables count toward a write.
                    for cat in catalog_batch:
                        catalog_tables = tables_by_catalog[cat]
                        if not catalog_tables:
                            continue
                        if write_callback:
                            accumulated_tables.extend(catalog_tables)
                            catalogs_since_last_write += 1
                        else:
                            raw_data["tables"].extend(catalog_tables)

                    # Write every CATALOG_WRITE_BATCH catalogs (or at the end)
                    if write_callback and accumulated_tables and (
                        catalogs_since_last_write >= CATALOG_WRITE_BATCH or is_last_batch
                    ):
                        print(f"[UC] Writing {len(accumulated_tables)} tables from {catalogs_since_last_write} catalogs")
                        write_callback("databricks_table", accumulated_tables)
                        # Start a fresh list rather than clear() the one the
                        # callback received, in case it keeps a reference.
                        accumulated_tables = []
                        catalogs_since_last_write = 0
                        gc.collect()  # Garbage collect after write

                    if batch_num % 5 == 0 or is_last_batch:
                        print(f"[UC] Progress: {batch_end}/{n_catalogs} catalogs processed, {out['uc_tables']} total tables")

            desc = self._describe_scope(allowlist, catalog_limit, schema_limit_per_catalog)
            print(f"[UC] Done ({desc}) → {out['uc_catalogs']} catalogs, {out['uc_schemas']} schemas, {out['uc_tables']} tables")
            print(f"[UC] Collected {len(raw_data['schemas'])} schema records, {len(raw_data['tables'])} table records")
            return out, raw_data

        except Exception as e:
            print(f"[UC-ERROR] Enumeration failed: {e}")
            import traceback
            traceback.print_exc()
            return out, raw_data


In [0]:
# Convenience function for backward compatibility
def enumerate_uc(
    base_url: str,
    headers: Dict[str, str],
    enable: bool = UC_ENABLE,
    allowlist: Optional[List[str]] = None,
    catalog_limit: int = UC_CATALOG_LIMIT,
    schema_limit_per_catalog: int = UC_SCHEMA_LIMIT_PER_CATALOG,
    max_workers: int = UC_MAX_WORKERS,
    spark = None,
    write_callback = None
) -> Tuple[Dict[str, int], Dict[str, List[Dict[str, Any]]]]:
    """
    Wrapper for Unity Catalog enumeration, kept for backward compatibility.

    Builds a ``UnityCatalogEnumerator`` and calls its
    ``enumerate_unity_catalog``. The default values are read from the
    notebook-level UC_* settings once, when this cell is executed (Python
    evaluates defaults at definition time). Re-run the cell after changing
    those settings.

    Args:
        base_url: Workspace URL prefix for the Unity Catalog REST endpoints.
        headers: HTTP headers sent with every request.
        allowlist: Catalogs to scan. A falsy value (None or []) falls back to
            UC_CATALOG_ALLOWLIST.
        spark: Forwarded to the enumerator, which does not currently use it
            (Spark keep-alive is managed globally in main.ipynb).
        write_callback: Optional function(table_type, records) to write data per catalog

    Returns tuple of (counts, raw_data):
    - counts: dict with catalog/schema/table counts
    - raw_data: dict with "schemas" and "tables" lists containing full metadata
    """
    enumerator = UnityCatalogEnumerator(base_url, headers)
    return enumerator.enumerate_unity_catalog(
        enable=enable,
        allowlist=allowlist or UC_CATALOG_ALLOWLIST,
        catalog_limit=catalog_limit,
        schema_limit_per_catalog=schema_limit_per_catalog,
        max_workers=max_workers,
        spark=spark,
        write_callback=write_callback
    )