# Unity Catalog Enumeration Logic

This notebook provides utilities for enumerating Unity Catalog objects (catalogs, schemas, and tables) through the Unity Catalog REST API (`/api/2.1/unity-catalog`).

In [0]:
import gc
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

import requests

# Seconds between Spark keep-alive pings issued while Unity Catalog is enumerated.
SPARK_KEEPALIVE_INTERVAL = 10 * 60  # 10 minutes

In [0]:
def log(msg: str):
    """Print ``msg`` only when the module-level ``VERBOSE_LOG`` flag is truthy."""
    if not VERBOSE_LOG:
        return
    print(msg)

def banner(txt: str):
    """Print ``txt`` framed by two 22-character '=' rules, with blank lines around it."""
    rule = "=" * 22
    print(f"\n{rule} {txt} {rule}\n")

def keep_spark_alive(spark, stop_event, interval=30):
    """
    Keep a Spark session alive by running a trivial query periodically.

    Run this in a background thread during long operations to prevent session
    timeout. The loop returns promptly once ``stop_event`` is set. The pause
    between pings uses ``stop_event.wait`` rather than ``time.sleep``, so a
    caller that sets the event and joins the thread does not wait up to
    ``interval`` seconds.

    Args:
        spark: SparkSession instance (anything whose ``.sql(query)`` result
            has a ``.count()`` method).
        stop_event: threading.Event to signal when to stop.
        interval: Seconds between keep-alive pings (default: 30).
    """
    ping_count = 0
    while not stop_event.is_set():
        try:
            # Trivial query; count() makes Spark actually run it.
            spark.sql("SELECT 1").count()
            ping_count += 1
            print(f"[KEEP-ALIVE] ✓ Ping #{ping_count} successful (interval: {interval}s)")
        except Exception as e:
            # Best-effort: a failed ping is reported (when verbose) and retried
            # after the next interval instead of killing the thread.
            if VERBOSE_LOG:
                print(f"[KEEP-ALIVE] ✗ Warning: {e}")
        # Event.wait returns immediately once the event is set, ending the loop
        # without sleeping out the rest of the interval.
        stop_event.wait(interval)

In [0]:
class UnityCatalogEnumerator:
    """Unity Catalog enumeration with threading and limits.

    Talks to the Unity Catalog REST endpoints under
    ``{base_url}/api/2.1/unity-catalog/`` (catalogs, schemas, tables). List
    endpoints are followed across ``next_page_token`` pages. Lookup helpers
    are best-effort: on HTTP or network errors they log via ``log`` and return
    empty results instead of raising.
    """

    # Path of the Unity Catalog REST API, appended to base_url.
    _API_ROOT = "/api/2.1/unity-catalog"

    def __init__(self, base_url: str, headers: Dict[str, str]):
        """
        Args:
            base_url: Workspace URL prefix; API paths such as
                "/api/2.1/unity-catalog/catalogs" are appended to it. A
                trailing "/" is stripped so joined URLs contain no "//".
            headers: HTTP headers sent with every request (presumably
                including authentication -- confirm with the caller).
        """
        self.base_url = base_url.rstrip("/")
        self.headers = headers

    def _get_paged(
        self, path: str, params: Dict[str, str], items_key: str
    ) -> Tuple[Optional[List[Dict[str, Any]]], int]:
        """GET a Unity Catalog list endpoint, following ``next_page_token``.

        Args:
            path: Endpoint name under the UC API root ("catalogs", "schemas",
                "tables").
            params: Query parameters for the first request; ``page_token`` is
                added for each following page.
            items_key: Response field holding the list of objects.

        Returns:
            ``(items, status)``. ``items`` is the concatenation of every page's
            ``items_key`` list, or None if a page answered with a non-200
            ``status`` (``status`` is 200 on success).

        Raises:
            Whatever ``requests.get`` / ``Response.json`` raise; callers catch.
        """
        url = f"{self.base_url}{self._API_ROOT}/{path}"
        query = dict(params)
        items: List[Dict[str, Any]] = []
        seen_tokens = set()
        while True:
            resp = requests.get(
                url, headers=self.headers, params=query, timeout=HTTP_TIMEOUT_SEC
            )
            if resp.status_code != 200:
                return None, resp.status_code
            payload = resp.json()
            items.extend(payload.get(items_key) or [])
            token = payload.get("next_page_token")
            # Last page reached. The seen-token guard also stops an endless
            # loop if the server ever echoes back a token it already sent.
            if not token or token in seen_tokens:
                return items, 200
            seen_tokens.add(token)
            query["page_token"] = token

    def list_tables_for_schema(self, catalog: str, schema: str) -> Tuple[int, int, int]:
        """Return (tables, managed, external) for catalog.schema.

        Counts are derived from ``list_tables_for_schema_with_data``, so an
        unreachable schema yields ``(0, 0, 0)``.
        """
        tables = self.list_tables_for_schema_with_data(catalog, schema)
        types = [t.get("table_type") for t in tables]
        return len(tables), types.count("MANAGED"), types.count("EXTERNAL")

    def list_tables_for_schema_with_data(self, catalog: str, schema: str) -> List[Dict[str, Any]]:
        """Return list of table metadata for catalog.schema.

        Each dict is the API's table object with ``_catalog`` and ``_schema``
        keys added for context. Returns [] on any HTTP or network error; the
        error or non-200 status is logged.
        """
        try:
            tables, status = self._get_paged(
                "tables", {"catalog_name": catalog, "schema_name": schema}, "tables"
            )
            if tables is None:
                log(f"[UC] Skip schema {catalog}.{schema} (status={status})")
                return []
            for tbl in tables:
                tbl["_catalog"] = catalog
                tbl["_schema"] = schema
            return tables
        except Exception as e:
            log(f"[UC-ERROR] {catalog}.{schema}: {e}")
            return []

    def get_catalogs(self, allowlist: Optional[List[str]] = None, catalog_limit: int = 0) -> List[str]:
        """Get the sorted list of catalogs to scan.

        Args:
            allowlist: If non-empty, keep only these catalogs (duplicates are
                ignored). "hive_metastore" is kept even when the catalogs API
                does not return it.
            catalog_limit: Keep only the first N catalogs after sorting;
                0 or None means no limit.

        Returns:
            Catalog names, or [] if the catalogs could not be listed (logged).
        """
        try:
            items, status = self._get_paged("catalogs", {}, "catalogs")
            if items is None:
                log(f"[UC-ERROR] Failed to get catalogs: HTTP status {status}")
                return []
            catalogs = [c["name"] for c in items]
        except Exception as e:
            log(f"[UC-ERROR] Failed to get catalogs: {e}")
            return []

        if allowlist:
            # Special case: if hive_metastore is in allowlist, always include it
            # (it's not returned by catalogs API but exists in legacy metastore)
            if "hive_metastore" in allowlist and "hive_metastore" not in catalogs:
                catalogs.append("hive_metastore")
                log("[UC] Added legacy hive_metastore to catalog list")

            available = set(catalogs)
            # dict.fromkeys drops repeated allowlist entries so no catalog is
            # scanned (and counted) twice.
            catalogs = [c for c in dict.fromkeys(allowlist) if c in available]

        catalogs = sorted(catalogs)
        if catalog_limit and catalog_limit > 0:
            catalogs = catalogs[:catalog_limit]
        return catalogs

    def get_schemas_for_catalog(self, catalog: str, schema_limit: int = 0) -> List[str]:
        """Get the sorted list of schema names for a catalog.

        Args:
            catalog: Catalog name.
            schema_limit: Keep only the first N schemas after sorting;
                0 or None means no limit.

        Returns:
            Schema names, or [] if the catalog could not be listed (logged).
        """
        try:
            items, status = self._get_paged("schemas", {"catalog_name": catalog}, "schemas")
            if items is None:
                log(f"[UC] Skip catalog {catalog} (status={status})")
                return []
            schemas = sorted(s["name"] for s in items)
        except Exception as e:
            log(f"[UC-ERROR] Failed to get schemas for {catalog}: {e}")
            return []

        if schema_limit and schema_limit > 0:
            schemas = schemas[:schema_limit]
        return schemas

    @staticmethod
    def _start_keep_alive(spark):
        """Start a daemon ``keep_spark_alive`` thread when a Spark session is given.

        Returns ``(stop_event, thread)``; ``thread`` is None when ``spark`` is falsy.
        """
        stop_event = threading.Event()
        if not spark:
            return stop_event, None
        thread = threading.Thread(
            target=keep_spark_alive,
            args=(spark, stop_event, SPARK_KEEPALIVE_INTERVAL),
            daemon=True,
        )
        thread.start()
        print(f"[UC] Spark keep-alive thread started ({SPARK_KEEPALIVE_INTERVAL}s interval)")
        return stop_event, thread

    @staticmethod
    def _stop_keep_alive(stop_event, thread):
        """Signal the keep-alive thread to stop and wait up to 5s for it to exit."""
        if thread is None:
            return
        stop_event.set()
        thread.join(timeout=5)
        if thread.is_alive():
            # keep_spark_alive only re-checks the event between pings, so it
            # may still be waiting; as a daemon thread it won't block exit.
            print("[UC] Spark keep-alive thread still running after 5s (daemon thread; will not block exit)")
        else:
            print("[UC] Spark keep-alive thread stopped")

    def _collect_catalog_tables(self, executor, catalog: str, schemas: List[str]) -> List[Dict[str, Any]]:
        """List tables for every schema of one catalog concurrently.

        Results are concatenated in schema order (not completion order), so
        the output is deterministic across runs.
        """
        futures = []
        for schema in schemas:
            futures.append(executor.submit(self.list_tables_for_schema_with_data, catalog, schema))
            time.sleep(0.002)  # Small delay to avoid overwhelming the API
        tables: List[Dict[str, Any]] = []
        for fut in futures:
            tables.extend(fut.result() or [])
        return tables

    def enumerate_unity_catalog(
        self,
        enable: bool = True,
        allowlist: Optional[List[str]] = None,
        catalog_limit: int = 0,
        schema_limit_per_catalog: int = 0,
        max_workers: int = 20,
        spark=None,
        write_callback=None,
    ) -> Tuple[Dict[str, int], Dict[str, List[Dict[str, Any]]]]:
        """
        Enumerate Unity Catalog with configurable limits and threading.

        Args:
            enable: When False, skip enumeration and return zeroed counts.
            allowlist: Catalog names to restrict the scan to; empty/None scans
                every catalog the API returns.
            catalog_limit: Keep only the first N (sorted) catalogs; 0 = no limit.
            schema_limit_per_catalog: Keep only the first N (sorted) schemas per
                catalog; 0 = no limit.
            max_workers: Thread-pool size for per-schema table listing
                (values below 1 are raised to 1; None uses the executor default).
            spark: Optional SparkSession. If given, a keep-alive thread pings it
                for the duration of the scan.
            write_callback: Optional function(table_type, records) called once
                per catalog with that catalog's tables (table_type is
                "databricks_table"). When given, tables are NOT accumulated in
                raw_data["tables"].

        Returns tuple of (counts, raw_data):
        - counts: dict with uc_catalogs, uc_schemas, uc_tables, managed_tables, external_tables
        - raw_data: dict with "schemas" and "tables" lists containing full metadata
        """
        banner("3/4 Unity Catalog Enumeration")
        out = {
            "uc_catalogs": 0,
            "uc_schemas": 0,
            "uc_tables": 0,
            "managed_tables": 0,
            "external_tables": 0,
        }
        raw_data = {"schemas": [], "tables": []}

        if not enable:
            print("[UC] Skipped (UC_ENABLE=False).")
            return out, raw_data

        # Keep the Spark session from timing out during the (possibly long) scan.
        stop_keep_alive, keep_alive_thread = self._start_keep_alive(spark)
        try:
            allowlist = allowlist or []

            catalogs = self.get_catalogs(allowlist, catalog_limit)
            if not catalogs:
                print("[UC] No catalogs found or accessible.")
                return out, raw_data

            out["uc_catalogs"] = len(catalogs)

            print(f"[UC] Scope: allowlist={allowlist or 'ALL'}, cat_limit={catalog_limit or 'ALL'}, schema_limit={schema_limit_per_catalog or 'ALL'}")
            print(f"[UC] Catalogs to scan: {len(catalogs)} → {catalogs[:5]}{' …' if len(catalogs)>5 else ''}")

            total_schemas = total_tables = managed = external = 0
            # ThreadPoolExecutor rejects max_workers <= 0.
            workers = None if max_workers is None else max(1, max_workers)

            with ThreadPoolExecutor(max_workers=workers) as ex:
                # Process catalogs ONE AT A TIME so only one catalog's tables are held in memory
                for cat_idx, cat in enumerate(catalogs, 1):
                    schemas = self.get_schemas_for_catalog(cat, schema_limit_per_catalog)
                    total_schemas += len(schemas)
                    print(f"[UC] [{cat_idx}/{len(catalogs)}] {cat}: {len(schemas)} schemas")

                    raw_data["schemas"].extend(
                        {"catalog_name": cat, "schema_name": s, "full_name": f"{cat}.{s}"}
                        for s in schemas
                    )
                    catalog_tables = self._collect_catalog_tables(ex, cat, schemas)

                    total_tables += len(catalog_tables)
                    for tbl in catalog_tables:
                        table_type = tbl.get("table_type")
                        if table_type == "MANAGED":
                            managed += 1
                        elif table_type == "EXTERNAL":
                            external += 1

                    if catalog_tables:
                        if write_callback:
                            # Streaming mode: hand this catalog's tables off immediately.
                            print(f"[UC] Writing {len(catalog_tables)} tables for catalog '{cat}'")
                            write_callback("databricks_table", catalog_tables)
                        else:
                            # Batch mode - collect all
                            raw_data["tables"].extend(catalog_tables)
                    # Drop our reference so a written batch can be reclaimed before the next catalog.
                    del catalog_tables

                    # Periodic garbage collection (every 100 catalogs)
                    if cat_idx % 100 == 0:
                        gc.collect()
                        print(f"[UC] Memory checkpoint at catalog {cat_idx}/{len(catalogs)}")

                    if cat_idx % 5 == 0:
                        print(f"[UC] Progress: {cat_idx}/{len(catalogs)} catalogs, {total_tables} tables")

            out["uc_schemas"] = total_schemas
            out["uc_tables"] = total_tables
            out["managed_tables"] = managed
            out["external_tables"] = external

            scope = []
            if allowlist:
                scope.append(f"allow={allowlist}")
            if catalog_limit:
                scope.append(f"cat_limit={catalog_limit}")
            if schema_limit_per_catalog:
                scope.append(f"schema_limit={schema_limit_per_catalog}")
            desc = ", ".join(scope) if scope else "full-scan"

            print(f"[UC] Done ({desc}) → {out['uc_catalogs']} catalogs, {out['uc_schemas']} schemas, {out['uc_tables']} tables")
            print(f"[UC] Collected {len(raw_data['schemas'])} schema records, {len(raw_data['tables'])} table records")
            return out, raw_data

        finally:
            self._stop_keep_alive(stop_keep_alive, keep_alive_thread)


In [0]:
# Convenience function for backward compatibility
def enumerate_uc(
    base_url: str,
    headers: Dict[str, str],
    enable: bool = UC_ENABLE,
    allowlist: Optional[List[str]] = None,
    catalog_limit: int = UC_CATALOG_LIMIT,
    schema_limit_per_catalog: int = UC_SCHEMA_LIMIT_PER_CATALOG,
    max_workers: int = UC_MAX_WORKERS,
    spark=None,
    write_callback=None,
) -> Tuple[Dict[str, int], Dict[str, List[Dict[str, Any]]]]:
    """
    Wrapper for Unity Catalog enumeration.

    Builds a ``UnityCatalogEnumerator`` and delegates to
    ``enumerate_unity_catalog``. The default values come from the module-level
    UC_* settings and are bound when this function is defined, so later
    reassignments of those settings do not change the defaults.

    Args:
        base_url: Workspace URL prefix for the Unity Catalog REST API.
        headers: HTTP headers sent with every request.
        enable: When False, enumeration is skipped and zeroed counts returned.
        allowlist: Catalog names to scan. Any falsy value (None or [])
            falls back to UC_CATALOG_ALLOWLIST.
        catalog_limit: Maximum number of catalogs to scan (0 = no limit).
        schema_limit_per_catalog: Maximum schemas per catalog (0 = no limit).
        max_workers: Thread-pool size for per-schema table listing.
        spark: SparkSession to keep alive during long enumeration
        write_callback: Optional function(table_type, records) to write data per catalog

    Returns tuple of (counts, raw_data):
    - counts: dict with catalog/schema/table counts
    - raw_data: dict with "schemas" and "tables" lists containing full metadata
    """
    enumerator = UnityCatalogEnumerator(base_url, headers)
    return enumerator.enumerate_unity_catalog(
        enable=enable,
        allowlist=allowlist or UC_CATALOG_ALLOWLIST,
        catalog_limit=catalog_limit,
        schema_limit_per_catalog=schema_limit_per_catalog,
        max_workers=max_workers,
        spark=spark,
        write_callback=write_callback,
    )