In [0]:
# #Run cell on cluster restart or if receiving error:
# #AttributeError: 'EMR' object has no attribute 'create_persistent_app_ui'
# %pip install --upgrade boto3 botocore
# %restart_python

## Databricks Spark Event Log Analyzer

This script analyzes Databricks clusters to extract performance metrics from Spark History Servers. It discovers Databricks clusters, connects to managed Spark History Server UIs, fetches application, job, stage, and SQL query data, and then processes this information into Spark DataFrames for performance analysis and optimization insights.

## Required Setup

i) Create a secret for the API token, then fetch it in code (this example uses `dbutils.secrets` with a secret scope named `shscreds`, but you can choose any scope and key names)

ii) Create a secret for the data plane URL, then fetch it in code (further instructions in the README)

iii) Create a secret for DATAPLANE_DOMAIN_DBAUTH, then fetch it in code (further instructions in the README)


In [0]:
import datetime
from datetime import timedelta, date
import logging, os, time
from logging.handlers import RotatingFileHandler

# ----------------------------------------------------------------------
# Configuration Parameters
# ----------------------------------------------------------------------
# Declare the notebook widgets. Every value is read back as a string and is
# parsed/validated below.
dbutils.widgets.text("timeout_seconds", "300", "Request Timeout (seconds)")
dbutils.widgets.text("max_applications", "10", "Max Applications to Analyze per Cluster")
dbutils.widgets.dropdown("environment", "dev", ["dev", "prod"], "Environment (dev/prod)")
dbutils.widgets.text("catalog_name", "", "Catalog (required)")
dbutils.widgets.text("schema_name", "", "Schema")
dbutils.widgets.text("volume_name", "profiler_logs_volume", "Volume for logs")
dbutils.widgets.text("max_clusters", "200", "Max Clusters to Analyze")
dbutils.widgets.text("batch_size", "10", "Batch Size (clusters to process concurrently)")
dbutils.widgets.text("batch_delay_seconds", "2", "Delay Between Batches (seconds)")
dbutils.widgets.text("max_endpoint_failures", "3", "Max Endpoint Failures per Endpoint Type")
dbutils.widgets.text("include_tasks", "false", "Set to true if you want to include task level metrics")

# Parse widget values. Each `or "<default>"` fallback covers a widget that was
# cleared to an empty string and mirrors the default declared above.
TIMEOUT_SECONDS = int(dbutils.widgets.get("timeout_seconds") or "300")
# Fallback aligned with the widget default of "10" (it previously fell back to "50").
MAX_APPLICATIONS = int(dbutils.widgets.get("max_applications") or "10")
MAX_CLUSTERS_RAW = int(dbutils.widgets.get("max_clusters") or "200")
# MAX_CLUSTERS is compared against the page counter of the clusters/list loop in
# DBXClusterDiscovery, so it is a page limit rather than a cluster count. The
# divisor 20 is presumably the API page size -- TODO confirm. Ceiling integer
# division keeps the value an int while allowing exactly as many loop iterations
# as the former float result (e.g. 30 clusters -> 2 pages).
_CLUSTERS_PER_PAGE = 20
MAX_CLUSTERS = (MAX_CLUSTERS_RAW + _CLUSTERS_PER_PAGE - 1) // _CLUSTERS_PER_PAGE
ENVIRONMENT = dbutils.widgets.get("environment").strip()
CATALOG_NAME = dbutils.widgets.get("catalog_name").strip()
# DEFAULT_SCHEMA_NAME is not defined in this cell. Use it when an earlier cell or
# %run provided it; otherwise fail with a clear message instead of a NameError.
SCHEMA_NAME = dbutils.widgets.get("schema_name").strip() or globals().get("DEFAULT_SCHEMA_NAME", "")
VOLUME_NAME = dbutils.widgets.get("volume_name").strip() or "profiler_logs_volume"
BATCH_SIZE = int(dbutils.widgets.get("batch_size") or "10")
BATCH_DELAY_SECONDS = int(dbutils.widgets.get("batch_delay_seconds") or "2")
MAX_ENDPOINT_FAILURES = int(dbutils.widgets.get("max_endpoint_failures") or "3")
# Kept as the lowercase string the user typed ("true"/"false"), not a bool.
TASK_BINARY = dbutils.widgets.get("include_tasks").strip().lower()

# UC Validation
if not CATALOG_NAME:
    raise ValueError("catalog widget must point to an existing catalog")
if not SCHEMA_NAME:
    raise ValueError("schema_name widget is empty and no DEFAULT_SCHEMA_NAME is defined")

# Create the schema and the log volume on first use; both statements are
# no-ops when the objects already exist.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG_NAME}.{SCHEMA_NAME}")
spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG_NAME}.{SCHEMA_NAME}.{VOLUME_NAME}")
VOLUME_BASE_PATH = f"/Volumes/{CATALOG_NAME}/{SCHEMA_NAME}/{VOLUME_NAME}"
LOG_DIR = f"{VOLUME_BASE_PATH}"

# Make sure the log directory exists; if the local filesystem call fails for
# any reason, fall back to creating it through dbutils.
try:
    os.makedirs(LOG_DIR, exist_ok=True)
except Exception:
    dbutils.fs.mkdirs(LOG_DIR)

# One log file per run, named after the timestamp at which this cell ran
run_id = time.strftime("%Y%m%d-%H%M%S")
LOG_FILE = f"{LOG_DIR}/run-{run_id}.txt"

# Detach and close whatever handlers a previous run of this cell left on the
# root logger, so messages are not emitted twice.
while logging.root.handlers:
    h = logging.root.handlers[0]
    logging.root.removeHandler(h)
    h.close()

# Root logger writes to the notebook output and to the run's log file.
# RotatingFileHandler is created without a size limit, so it acts as a single
# append-mode file.
_log_handlers = [
    logging.StreamHandler(),
    RotatingFileHandler(LOG_FILE, mode='a', encoding="utf-8"),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    handlers=_log_handlers,
)

# Raise the threshold of the chatty library loggers
for _noisy_name, _noisy_level in (("py4j", logging.ERROR), ("databricks", logging.WARNING)):
    logging.getLogger(_noisy_name).setLevel(_noisy_level)

logger = logging.getLogger(__name__)

# SHS Retry Configuration
# Settings read by SparkHistoryServerClient: retry/backoff, timeouts (seconds),
# throttling, error classification, per-endpoint overrides and logging toggles.
SHS_RETRY_CONFIG = dict(
    retry_settings=dict(
        max_retries=3,
        initial_delay_seconds=5.0,
        max_delay_seconds=120.0,
        backoff_factor=2.0,
        enable_jitter=True,
        jitter_max=1.0,
    ),
    timeout_settings=dict(
        connection_timeout=30,
        read_timeout=180,
        total_timeout=300,
    ),
    throttling_settings=dict(
        request_delay_seconds=1.0,
        enable_adaptive_throttling=True,
        consecutive_error_threshold=5,
        adaptive_delay_multiplier=1.5,
    ),
    spark_history_server_errors=dict(
        retryable_http_codes=[429, 500, 502, 503, 504],
        non_retryable_codes=[400, 401, 403, 404],
        connection_errors=["ConnectionError", "Timeout", "HTTPError"],
    ),
    # Every endpoint type overrides max_retries to 2; only the read timeout differs.
    endpoint_specific_settings={
        endpoint_name: {"max_retries": 2, "read_timeout": endpoint_read_timeout}
        for endpoint_name, endpoint_read_timeout in (
            ("stages", 600),
            ("jobs", 600),
            ("sql", 600),
            ("executors", 300),
            ("applications", 300),
        )
    },
    logging=dict(
        enable_retry_logging=True,
        enable_success_logging=True,
        enable_throttle_logging=True,
        log_shs_url=True,
    ),
)

# Log final configuration
print("Configuration:")
print(f" Environment: {ENVIRONMENT}")
print(f" Timeout: {TIMEOUT_SECONDS} seconds")
print(f" Max Applications per Cluster: {MAX_APPLICATIONS}")
# MAX_CLUSTERS is the page limit derived from the widget value, so report the
# configured cluster count and show the derived limit next to it. The `g`
# format renders the limit the same way whether it is an int or a float.
print(f" Max Clusters to Analyze: {MAX_CLUSTERS_RAW} (up to {MAX_CLUSTERS:g} list pages)")
print(f" Catalog: {CATALOG_NAME}")
print(f" Schema: {SCHEMA_NAME}")
print(f" Log file: {LOG_FILE}")
print(f" Batch Size: {BATCH_SIZE} clusters")
print(f" Batch Delay: {BATCH_DELAY_SECONDS} seconds")
print(f" Max Endpoint Failures: {MAX_ENDPOINT_FAILURES}")


### DBX Cluster Discovery

In [0]:
import time
from datetime import datetime
from typing import List, Dict, Optional
from botocore.exceptions import ClientError

class DBXClusterDiscovery:
    """Discovery of terminated Databricks job clusters through the Clusters REST API."""

    def __init__(self, region: Optional[str] = None):
        """
        Initialize the cluster discovery helper.

        :param region: Accepted for backward compatibility and stored on the
            instance; discovery itself does not read it.
        """
        self.region = region

    def discover_clusters(
        self,
        states: Optional[List[str]] = None,
        name_filter: Optional[str] = None,
        max_clusters: int = 100,
        created_after: Optional[datetime] = None,
        created_before: Optional[datetime] = None) -> List[Dict]:
        """
        List terminated job clusters of the current workspace.

        The request always filters on ``cluster_sources=JOB`` and
        ``cluster_states=TERMINATED`` and follows ``next_page_token`` until the
        API stops returning one or the module-level ``MAX_CLUSTERS`` page limit
        is reached.

        :param states: Currently unused; kept for interface compatibility.
        :param name_filter: Currently unused; kept for interface compatibility.
        :param max_clusters: Currently unused; the page limit comes from ``MAX_CLUSTERS``.
        :param created_after: Currently unused; kept for interface compatibility.
        :param created_before: Currently unused; kept for interface compatibility.
        :returns: One dict per cluster with the keys ``cluster_id``,
            ``cluster_name``, ``cluster_source``, ``state`` and ``spark_context_id``.
        :raises RuntimeError: If the workspace API URL cannot be read from the notebook context.
        :raises Exception: Any request, HTTP-status or JSON error is logged and re-raised.
        """
        try:
            token = dbutils.secrets.get(scope="shscreds", key="token")
            databricks_workspace_url = (
                dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
            )
            if not databricks_workspace_url:
                raise RuntimeError("Could not determine the Databricks workspace API URL from the notebook context.")

            token_header = {'Authorization': 'Bearer {0}'.format(token)}
            list_url = f'{databricks_workspace_url}/api/2.1/clusters/list'
            base_params = {
                'filter_by.cluster_sources': 'JOB',
                'filter_by.cluster_states': 'TERMINATED',
            }
            allclusters = []
            page_token = None
            pages_fetched = 0

            while pages_fetched < MAX_CLUSTERS:
                # Passing the token through `params` lets requests URL-encode it.
                params = dict(base_params)
                if page_token:
                    params['page_token'] = page_token

                resp = requests.get(list_url, headers=token_header, params=params, timeout=TIMEOUT_SECONDS)
                # Surface auth/permission/server errors instead of treating the
                # error payload as an empty page.
                resp.raise_for_status()
                respjson = resp.json()

                # A page without clusters may omit the key entirely.
                for cluster in respjson.get('clusters') or []:
                    allclusters.append({
                        'cluster_id': cluster.get('cluster_id'),
                        'cluster_name': cluster.get('cluster_name'),
                        'cluster_source': cluster.get('cluster_source'),
                        'state': cluster.get('state'),
                        'spark_context_id': cluster.get('spark_context_id'),
                    })

                pages_fetched += 1
                # The last page carries no token (missing key or empty string);
                # stop there rather than requesting "page_token=None".
                page_token = respjson.get('next_page_token')
                if not page_token:
                    break

            logger.info("✅ Discovered %s clusters", len(allclusters))
            return allclusters
        except Exception as e:
            logger.error("❌ Failed to discover clusters: %s", e, exc_info=True)
            raise


### Spark History Server REST Interaction

In [0]:
import json
import random
import time
from typing import List, Any, Dict, Optional
import requests
import logging

logger = logging.getLogger(__name__)

class SparkHistoryServerClient:
    """Client for interacting with Spark History Server REST API."""

    def __init__(self, base_url: str, session: requests.Session, cookies: str):
        """
        Initialize the Spark History Server client.

        :param base_url: Base URL for the Spark History Server; ``api/v1`` is appended to it.
        :param session: Configured HTTP session with authentication.
        :param cookies: Cookies forwarded unchanged to every ``session.get`` call.
        :raises ValueError: If base_url is empty or not a string.
        :raises TypeError: If session is not a requests.Session.
        """
        if not base_url or not isinstance(base_url, str):
            raise ValueError("base_url must be a non-empty string.")
        if not isinstance(session, requests.Session):
            raise TypeError("session must be a requests.Session object.")

        self.base_url = base_url
        self.session = session
        self.cookies = cookies
        # Join with exactly one slash so a base_url without a trailing "/" no
        # longer produces ".../hostapi/v1".
        self.api_base = f"{base_url.rstrip('/')}/api/v1"
        self.retry_config = SHS_RETRY_CONFIG
        self.consecutive_errors = {}  # Track consecutive errors per endpoint
        self.last_request_time = 0  # For throttling

    def _calculate_jitter(self, delay: float) -> float:
        """Return ``delay`` plus a random jitter in [0, jitter_max], or ``delay`` unchanged when jitter is disabled."""
        retry_settings = self.retry_config["retry_settings"]
        if not retry_settings["enable_jitter"]:
            return delay
        return delay + random.uniform(0, retry_settings["jitter_max"])

    def _get_endpoint_type(self, endpoint: str) -> str:
        """
        Map an endpoint path to the key used in ``endpoint_specific_settings``.

        Markers are checked in a fixed order (sql, stages, jobs, allexecutors);
        any other path containing ``applications`` is ``'applications'`` and
        everything else is ``'default'``.
        """
        for marker, endpoint_type in (
            ('/sql', 'sql'),
            ('/stages', 'stages'),
            ('/jobs', 'jobs'),
            ('/allexecutors', 'executors'),
        ):
            if marker in endpoint:
                return endpoint_type
        return 'applications' if 'applications' in endpoint else 'default'

    def _apply_throttling(self, endpoint_type: str):
        """
        Sleep as needed so requests are spaced by the configured delay.

        The delay grows exponentially once the consecutive-error count of
        ``endpoint_type`` reaches the configured threshold, capped at
        ``max_delay_seconds``.
        """
        config = self.retry_config
        throttling = config["throttling_settings"]

        # Base request delay
        base_delay = throttling["request_delay_seconds"]

        # Adaptive throttling based on consecutive errors
        if throttling["enable_adaptive_throttling"]:
            consecutive_errors = self.consecutive_errors.get(endpoint_type, 0)
            threshold = throttling["consecutive_error_threshold"]
            if consecutive_errors >= threshold:
                adaptive_delay = base_delay * (
                    throttling["adaptive_delay_multiplier"] ** (consecutive_errors - threshold + 1)
                )
                base_delay = min(adaptive_delay, config["retry_settings"]["max_delay_seconds"])

                if config["logging"]["enable_throttle_logging"]:
                    logger.info("Applying adaptive throttling for %s: %d consecutive errors, delay: %.2fs",
                                endpoint_type, consecutive_errors, base_delay)

        # Ensure minimum time between requests
        time_since_last = time.time() - self.last_request_time
        if time_since_last < base_delay:
            sleep_time = base_delay - time_since_last
            if config["logging"]["enable_throttle_logging"]:
                logger.info("Throttling request for %s: sleeping %.2fs", endpoint_type, sleep_time)
            time.sleep(sleep_time)

        self.last_request_time = time.time()

    def _is_retryable_error(self, exception: Exception) -> bool:
        """
        Determine if an error is retryable based on configuration.

        HTTP errors that carry a response are classified by status code. Any
        other exception is retryable when its class, or one of its base
        classes, is named in ``connection_errors`` (so subclasses such as
        ``ReadTimeout`` count as ``Timeout``).

        NOTE: ``_make_request`` currently retries every exception and does not
        call this helper.
        """
        error_config = self.retry_config["spark_history_server_errors"]

        if isinstance(exception, requests.exceptions.HTTPError):
            # An HTTPError can be raised without a response attached; without
            # a status code, fall through to the name-based check.
            response = getattr(exception, "response", None)
            if response is not None:
                status_code = response.status_code
                if status_code in error_config["non_retryable_codes"]:
                    return False
                return status_code in error_config["retryable_http_codes"]

        # Check connection errors
        retryable_names = set(error_config["connection_errors"])
        return any(cls.__name__ in retryable_names for cls in type(exception).__mro__)

    def _make_request(self, endpoint: str, params: Optional[Dict] = None, max_retries: Optional[int] = None) -> Any:
        """
        Make a REST API request with retry logic based on SHS_RETRY_CONFIG.

        :param endpoint: The API endpoint to call (e.g., 'applications/app-123/jobs').
        :param params: A dictionary of query parameters for the request.
        :param max_retries: Optional override for max retries. Precedence is
            this argument, then the endpoint-specific setting, then the global setting.
        :returns: The JSON response from the API.
        :raises Exception: The last exception seen once all attempts have failed.
        """
        url = f"{self.api_base}/{endpoint}"
        endpoint_type = self._get_endpoint_type(endpoint)
        config = self.retry_config
        retry_settings = config["retry_settings"]
        cookies = self.cookies

        # Get endpoint-specific settings
        endpoint_settings = config["endpoint_specific_settings"].get(endpoint_type, {})

        # Determine max retries (parameter > endpoint-specific > global config)
        if max_retries is None:
            max_retries = endpoint_settings.get("max_retries", retry_settings["max_retries"])
        # A negative value would skip the loop entirely and end in `raise None`.
        max_retries = max(0, max_retries)

        # Determine timeouts
        connection_timeout = config["timeout_settings"]["connection_timeout"]
        read_timeout = endpoint_settings.get("read_timeout", config["timeout_settings"]["read_timeout"])
        timeout = (connection_timeout, read_timeout)

        if config["logging"]["log_shs_url"] and config["logging"]["enable_retry_logging"]:
            logger.info("Making request to: %s (endpoint_type: %s, max_retries: %d)", url, endpoint_type, max_retries)

        # Apply throttling before making request
        self._apply_throttling(endpoint_type)

        last_exception = None

        # The retry loop starts with attempt 0 (the first try)
        for attempt in range(max_retries + 1):
            if attempt > 0:
                # Exponential backoff, capped, with optional jitter
                base_delay = retry_settings["initial_delay_seconds"] * (
                    retry_settings["backoff_factor"] ** (attempt - 1)
                )
                delay = min(base_delay, retry_settings["max_delay_seconds"])
                delay_with_jitter = self._calculate_jitter(delay)

                if config["logging"]["enable_retry_logging"]:
                    logger.info("Retrying %s in %.2f seconds (attempt %d/%d)...",
                                endpoint, delay_with_jitter, attempt, max_retries)

                time.sleep(delay_with_jitter)

            try:
                response = self.session.get(url, cookies=cookies, params=params, timeout=timeout, allow_redirects=True)
                # NOTE(review): the HTTP status is deliberately not checked
                # (raise_for_status() is disabled), and every exception is
                # retried without consulting _is_retryable_error. Confirm
                # whether both should be re-enabled.
                # Parse before touching the error counter so a body that is
                # not valid JSON counts as a failure.
                payload = response.json()
            except Exception as e:
                last_exception = e

                # Track consecutive errors for adaptive throttling
                self.consecutive_errors[endpoint_type] = self.consecutive_errors.get(endpoint_type, 0) + 1

                if config["logging"]["enable_retry_logging"]:
                    logger.warning("Attempt %d/%d for %s failed: %s", attempt + 1, max_retries + 1, url, e)
                continue

            # Reset consecutive errors on success
            self.consecutive_errors[endpoint_type] = 0

            if config["logging"]["enable_success_logging"]:
                logger.debug("✅ Successfully retrieved data from %s (attempt %d)", endpoint, attempt + 1)

            return payload

        # All retries exhausted
        if config["logging"]["enable_retry_logging"]:
            logger.error("All %d retry attempts failed for %s.", max_retries + 1, url)

        raise last_exception

    def get_applications(self, status: Optional[str] = None, limit: int = 100) -> List[Dict]:
        """Fetch the application list, optionally filtered by status and capped at ``limit``."""
        logger.info("Fetching applications (status: %s, limit: %s)", status, limit)
        params = {'limit': limit}
        if status:
            params['status'] = status
        return self._make_request("applications", params)

    def get_application_details(self, app_id: str) -> Dict:
        """Fetch the details of a single application."""
        logger.info("Fetching application details for: %s", app_id)
        return self._make_request(f"applications/{app_id}")

    def get_application_jobs(self, app_id: str, attempt_id: Optional[str] = None, status: Optional[str] = None) -> List[Dict]:
        """
        Fetch the jobs of an application, optionally filtered by status.

        ``attempt_id`` is accepted for interface compatibility but is not part
        of the request path.
        """
        endpoint = f"applications/{app_id}/jobs"
        params = {'status': status} if status else {}
        return self._make_request(endpoint, params)

    def get_application_stages(self, app_id: str, attempt_id: Optional[str] = None, status: Optional[str] = None) -> List[Dict]:
        """
        Fetch the stages of an application with details and summaries enabled.

        ``attempt_id`` is accepted for interface compatibility but is not part
        of the request path.
        """
        endpoint = f"applications/{app_id}/stages"
        params = {'details': 'true', 'withSummaries': 'true'}
        if status:
            params['status'] = status
        return self._make_request(endpoint, params)

    def get_stage_tasks(self, app_id: str, attempt_id: str, stage_id: int, stage_attempt: int = 0) -> List[Dict]:
        """Fetch the task list of one stage attempt."""
        endpoint = f"applications/{app_id}/{attempt_id}/stages/{stage_id}/{stage_attempt}/taskList"
        return self._make_request(endpoint)

    def get_stage_task_summary(self, app_id: str, attempt_id: str, stage_id: int, stage_attempt: int = 0) -> Dict:
        """Fetch the task summary of one stage attempt at the 0/25/50/75/100% quantiles."""
        endpoint = f"applications/{app_id}/{attempt_id}/stages/{stage_id}/{stage_attempt}/taskSummary"
        params = {'quantiles': "0.0,0.25,0.5,0.75,1.0"}
        return self._make_request(endpoint, params)

    def get_application_executors(self, app_id: str, attempt_id: Optional[str] = None) -> List[Dict]:
        """Fetch all executors of an application; the attempt id is included in the path when given."""
        endpoint = f"applications/{app_id}/{attempt_id}/allexecutors" if attempt_id else f"applications/{app_id}/allexecutors"
        return self._make_request(endpoint)

    def get_application_sql_queries(self, app_id: str, attempt_id: Optional[str] = None) -> List[Dict]:
        """
        Fetch the SQL executions of an application with details and plan description.

        ``attempt_id`` is accepted for interface compatibility but is not part
        of the request path.
        """
        endpoint = f"applications/{app_id}/sql"
        params = {'details': 'true', 'planDescription': 'true'}
        return self._make_request(endpoint, params)


### Metrics Analysis Functions

In [0]:
import json
import logging
from typing import List, Dict, Tuple, Optional, Any

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import (
    StructType, StructField, StringType, LongType, DoubleType, BooleanType
)

logger = logging.getLogger(__name__)


class SparkMetricsAnalyzer:
    """Analyzer for Spark application metrics and performance data."""

    def __init__(self, spark: SparkSession):
        """
        Initialize the metrics analyzer.

        :param spark: Spark session instance for DataFrame creation.
        """
        # The type check below is disabled, so any object is accepted here and
        # no TypeError is raised by the constructor.
        #if not isinstance(spark, SparkSession):
        #    raise TypeError("Parameter 'spark' must be a SparkSession instance")
        self.spark = spark

    def analyze_application_performance(self, app_data: Dict, latest_attempt_id: str) -> Dict:
        """
        Analyze performance metrics for a single application.

        Values are taken from the attempt whose ``attemptId`` equals
        ``latest_attempt_id``; when no such attempt exists they are taken from
        ``app_data`` itself.

        :param app_data: Application data from Spark History Server.
        :param latest_attempt_id: The latest attempt ID to extract data from.
        :returns: Performance analysis results dictionary.
        """
        # `attempts` may be present but null; treat that like an empty list.
        attempts = app_data.get('attempts') or []
        latest_attempt_info = next((a for a in attempts if a.get('attemptId') == latest_attempt_id), {})
        source_data = latest_attempt_info or app_data
        # A null duration would break the minutes calculation below; treat it
        # as 0, matching how the stage analysis handles missing counters.
        duration_ms = source_data.get('duration') or 0

        return {
            'application_id': app_data.get('id'),
            'application_name': app_data.get('name', 'Unknown'),
            'duration_ms': duration_ms,
            'duration_minutes': round(duration_ms / 60000, 2),
            'start_time': source_data.get('startTime'),
            'end_time': source_data.get('endTime'),
            'spark_version': latest_attempt_info.get('appSparkVersion') or app_data.get('sparkVersion', 'Unknown'),
        }

    def analyze_job_performance(self, jobs: List[Dict]) -> List[Dict]:
        """
        Analyze performance metrics for jobs.

        :param jobs: List of job data dictionaries.
        :returns: List of job performance analysis dictionaries, one per job,
            including ``task_success_rate`` (completed / total tasks, in percent).
        """
        job_analysis = []
        for job in jobs:
            # Counters may be present but null; treat them as 0 so the rate
            # calculation cannot fail, matching the stage analysis.
            num_tasks = job.get('numTasks') or 0
            num_completed = job.get('numCompletedTasks') or 0
            analysis = {
                'job_id': job.get('jobId'),
                'job_name': job.get('name', 'Unknown'),
                'status': job.get('status', 'UNKNOWN'),
                'submission_time': job.get('submissionTime'),
                'completion_time': job.get('completionTime'),
                'num_tasks': num_tasks,
                'num_completed_tasks': num_completed,
                'num_failed_tasks': job.get('numFailedTasks', 0),
                # Stored as the string form of the list
                'stage_ids': str(job.get('stageIds', [])),
                'task_success_rate': round((num_completed / num_tasks) * 100.0, 2) if num_tasks > 0 else 0.0,
            }
            job_analysis.append(analysis)
        return job_analysis

    def analyze_stage_performance(self, stages: List[Dict]) -> List[Dict]:
        """
        Build one flat metrics record per stage.

        Missing or null counters are read as 0. On top of the raw counters each
        record carries a task completion rate (percent), the average executor
        run time per task, and input+output / shuffle volumes in MB. The
        ``cluster_id``, ``cluster_name`` and ``application_id`` fields are left
        empty for the caller to fill in.

        :param stages: List of stage data dictionaries.
        :returns: List of stage performance analysis dictionaries.
        """
        bytes_per_mb = 1024 * 1024
        records = []

        for stage in stages:
            def counter(key: str):
                # Missing and null values both collapse to 0
                return stage.get(key) or 0

            task_total = counter('numTasks')
            tasks_done = counter('numCompleteTasks')
            run_time = counter('executorRunTime')
            bytes_in = counter('inputBytes')
            bytes_out = counter('outputBytes')
            shuffle_in = counter('shuffleReadBytes')
            shuffle_out = counter('shuffleWriteBytes')

            # Rates stay at 0.0 unless the stage has tasks (and run time)
            completion_rate = 0.0
            avg_run_time = 0.0
            if task_total > 0:
                completion_rate = round((tasks_done / task_total) * 100.0, 2)
                if run_time > 0:
                    avg_run_time = round(run_time / task_total, 2)

            records.append({
                'stage_id': stage.get('stageId'),
                'stage_name': stage.get('name') or 'Unknown',
                'status': stage.get('status') or 'UNKNOWN',
                'num_tasks': task_total,
                'num_active_tasks': counter('numActiveTasks'),
                'num_complete_tasks': tasks_done,
                'num_failed_tasks': counter('numFailedTasks'),
                'executor_run_time': run_time,
                'executor_cpu_time': counter('executorCpuTime'),
                'submission_time': stage.get('submissionTime'),
                'first_task_launched_time': stage.get('firstTaskLaunchedTime'),
                'completion_time': stage.get('completionTime'),
                'input_bytes': bytes_in,
                'output_bytes': bytes_out,
                'shuffle_read_bytes': shuffle_in,
                'shuffle_write_bytes': shuffle_out,
                'memory_bytes_spilled': counter('memoryBytesSpilled'),
                'disk_bytes_spilled': counter('diskBytesSpilled'),
                'task_completion_rate': completion_rate,
                'avg_executor_run_time_per_task': avg_run_time,
                'total_data_processed_mb': round((bytes_in + bytes_out) / bytes_per_mb, 2),
                'shuffle_data_mb': round((shuffle_in + shuffle_out) / bytes_per_mb, 2),
                'cluster_id': '',  # Will be populated by caller
                'cluster_name': '',  # Will be populated by caller
                'application_id': '',  # Will be populated by caller
            })

        return records

    def analyze_task_performance(self, tasks: List[Dict]) -> List[Dict]:
        """
        Project each task record onto the flat set of fields kept for analysis.

        Fields missing from a task come back as ``None``. ``stage_id`` and
        ``stage_attempt_id`` are read from snake_case keys of the task dict.

        :param tasks: List of task data dictionaries.
        :returns: List of task performance analysis dictionaries.
        """
        # (output key, key in the task dict), in output order
        field_map = (
            ('task_id', 'taskId'),
            ('index', 'index'),
            ('attempt', 'attempt'),
            ('launch_time', 'launchTime'),
            ('duration', 'duration'),
            ('executor_id', 'executorId'),
            ('host', 'host'),
            ('status', 'status'),
            ('task_locality', 'taskLocality'),
            ('speculative', 'speculative'),
            ('stage_id', 'stage_id'),
            ('stage_attempt_id', 'stage_attempt_id'),
        )
        return [
            {out_key: task.get(src_key) for out_key, src_key in field_map}
            for task in tasks
        ]

    def analyze_sql_queries(self, sql_queries: List[Dict]) -> List[Dict]:
        """
        Reduce each SQL execution record to a few headline fields plus the
        complete record serialized as JSON.

        :param sql_queries: List of SQL query data dictionaries.
        :returns: List of processed SQL query analysis dictionaries.
        """
        return [
            {
                'sql_id': entry.get("id"),
                'description': entry.get("description", "N/A"),
                'status': entry.get("status"),
                'duration_ms': entry.get("duration", 0),
                'submission_time': entry.get("submissionTime"),
                # Keep the whole record so nothing is lost by the projection
                'sql_raw_json': json.dumps(entry),
            }
            for entry in sql_queries
        ]

    def analyze_executor_performance(self, executors: List[Dict]) -> List[Dict]:
        """
        Flatten executor records into snake_case analysis rows.

        Most fields are copied as-is (``None`` when absent). ``is_blacklisted``
        defaults to ``False`` and ``executor_logs`` is the JSON-serialized
        ``executorLogs`` mapping (``{}`` when absent).

        :param executors: List of executor data dictionaries from Spark History Server.
        :returns: List of executor performance analysis dictionaries.
        """
        # (output key, key in the executor dict) for the plain copies that come
        # before and after 'is_blacklisted' in the output order
        leading_fields = (
            ('executor_id', 'id'),
            ('host_port', 'hostPort'),
            ('is_active', 'isActive'),
            ('rdd_blocks', 'rddBlocks'),
            ('memory_used', 'memoryUsed'),
            ('disk_used', 'diskUsed'),
            ('total_cores', 'totalCores'),
            ('max_tasks', 'maxTasks'),
            ('active_tasks', 'activeTasks'),
            ('failed_tasks', 'failedTasks'),
            ('completed_tasks', 'completedTasks'),
            ('total_tasks', 'totalTasks'),
            ('total_duration', 'totalDuration'),
            ('total_gc_time', 'totalGCTime'),
            ('total_input_bytes', 'totalInputBytes'),
            ('total_shuffle_read', 'totalShuffleRead'),
            ('total_shuffle_write', 'totalShuffleWrite'),
        )
        trailing_fields = (
            ('max_memory', 'maxMemory'),
            ('add_time', 'addTime'),
        )

        rows = []
        for executor in executors:
            row = {out_key: executor.get(src_key) for out_key, src_key in leading_fields}
            row['is_blacklisted'] = executor.get('isBlacklisted', False)
            for out_key, src_key in trailing_fields:
                row[out_key] = executor.get(src_key)
            row['executor_logs'] = json.dumps(executor.get('executorLogs', {}))
            rows.append(row)
        return rows

    def analyze_task_summaries(self, task_summaries: List[Dict]) -> List[Dict]:
        """
        Pair each task summary's identifying keys with its full JSON payload.

        :param task_summaries: List of task summary data dictionaries.
        :returns: List of processed task summary analysis dictionaries with the
            keys ``application_id``, ``stage_id``, ``stage_attempt_id`` and ``raw_json``.
        """
        id_keys = ('application_id', 'stage_id', 'stage_attempt_id')
        return [
            {
                **{key: item.get(key) for key in id_keys},
                'raw_json': json.dumps(item),
            }
            for item in task_summaries
        ]

    def create_dynamic_dataframes(
        self,
        applications_analysis: List[Dict],
        jobs_analysis: List[Dict],
        stages_analysis: List[Dict],
        tasks_analysis: List[Dict],
        sql_analysis: List[Dict],
        executors_analysis: List[Dict],
        task_summaries_analysis: List[Dict]
        ) -> Tuple[Optional[DataFrame], Optional[DataFrame], Optional[DataFrame], Optional[DataFrame], Optional[DataFrame], Optional[DataFrame], Optional[DataFrame]]:
        """
        Turn the seven analysis record lists into Spark DataFrames with explicit schemas.

        Every schema starts with the same four identity columns (``cluster_id``,
        ``cluster_name``, ``application_id``, ``attempt_id``) and all columns are nullable.

        :returns: ``(applications, jobs, stages, tasks, sql, executors, task_summaries)``.
            An entry is ``None`` when its input list is empty or when building/counting
            the DataFrame raised (the failure is logged, not propagated).
        """
        def schema_of(*columns) -> StructType:
            # Each entry is (column name, Spark type class); every field is nullable.
            return StructType([StructField(col_name, col_type(), True) for col_name, col_type in columns])

        def build_frame(rows: List[Dict], schema: StructType, label: str) -> Optional[DataFrame]:
            if not rows:
                logger.info("No data provided for %s DataFrame.", label)
                return None
            try:
                frame = self.spark.createDataFrame(rows, schema=schema)
                # count() runs a Spark action, so problems surface here rather than downstream.
                logger.info("✅ Created %s DataFrame with %d rows.", label, frame.count())
            except Exception as e:
                logger.error("❌ Failed to create %s DataFrame: %s", label, str(e), exc_info=True)
                return None
            return frame

        # Identity columns shared by every table.
        identity = (
            ("cluster_id", StringType), ("cluster_name", StringType),
            ("application_id", StringType), ("attempt_id", StringType),
        )

        applications_schema = schema_of(
            *identity,
            ("application_name", StringType), ("duration_ms", LongType),
            ("duration_minutes", DoubleType), ("start_time", StringType),
            ("end_time", StringType), ("spark_version", StringType),
        )
        jobs_schema = schema_of(
            *identity,
            ("job_id", LongType), ("job_name", StringType),
            ("status", StringType), ("submission_time", StringType),
            ("completion_time", StringType), ("num_tasks", LongType),
            ("num_completed_tasks", LongType), ("num_failed_tasks", LongType),
            ("stage_ids", StringType), ("task_success_rate", DoubleType),
        )
        stages_schema = schema_of(
            *identity,
            ("stage_id", LongType), ("stage_name", StringType),
            ("status", StringType), ("num_tasks", LongType),
            ("num_active_tasks", LongType), ("num_complete_tasks", LongType),
            ("num_failed_tasks", LongType), ("executor_run_time", LongType),
            ("executor_cpu_time", LongType), ("submission_time", StringType),
            ("first_task_launched_time", StringType), ("completion_time", StringType),
            ("input_bytes", LongType), ("output_bytes", LongType),
            ("shuffle_read_bytes", LongType), ("shuffle_write_bytes", LongType),
            ("memory_bytes_spilled", LongType), ("disk_bytes_spilled", LongType),
            ("task_completion_rate", DoubleType), ("avg_executor_run_time_per_task", DoubleType),
            ("total_data_processed_mb", DoubleType), ("shuffle_data_mb", DoubleType),
        )
        tasks_schema = schema_of(
            *identity,
            ("stage_id", LongType), ("stage_attempt_id", LongType),
            ("task_id", LongType), ("index", LongType),
            ("attempt", LongType), ("launch_time", StringType),
            ("duration", LongType), ("executor_id", StringType),
            ("host", StringType), ("status", StringType),
            ("task_locality", StringType), ("speculative", BooleanType),
        )
        sql_schema = schema_of(
            *identity,
            ("sql_id", LongType), ("description", StringType),
            ("status", StringType), ("duration_ms", LongType),
            ("submission_time", StringType), ("sql_raw_json", StringType),
        )
        executors_schema = schema_of(
            *identity,
            ("executor_id", StringType), ("host_port", StringType),
            ("is_active", BooleanType), ("rdd_blocks", LongType),
            ("memory_used", LongType), ("disk_used", LongType),
            ("total_cores", LongType), ("max_tasks", LongType),
            ("active_tasks", LongType), ("failed_tasks", LongType),
            ("completed_tasks", LongType), ("total_tasks", LongType),
            ("total_duration", LongType), ("total_gc_time", LongType),
            ("total_input_bytes", LongType), ("total_shuffle_read", LongType),
            ("total_shuffle_write", LongType), ("is_blacklisted", BooleanType),
            ("max_memory", LongType), ("add_time", StringType),
            ("executor_logs", StringType),
        )
        task_summaries_schema = schema_of(
            *identity,
            ("stage_id", LongType), ("stage_attempt_id", LongType),
            ("raw_json", StringType),
        )

        # Build the frames in the same order as the returned tuple.
        datasets = (
            (applications_analysis, applications_schema, "applications"),
            (jobs_analysis, jobs_schema, "jobs"),
            (stages_analysis, stages_schema, "stages"),
            (tasks_analysis, tasks_schema, "tasks"),
            (sql_analysis, sql_schema, "sql"),
            (executors_analysis, executors_schema, "executors"),
            (task_summaries_analysis, task_summaries_schema, "task_summaries"),
        )
        return tuple(build_frame(rows, schema, label) for rows, schema, label in datasets)



### Cluster Analyzer

In [0]:
import concurrent.futures
import json
import logging
from typing import List, Dict, Any, Optional
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql import SparkSession

# Logger shared by the helper functions defined in this cell.
logger = logging.getLogger(__name__)

# A private constant to limit the number of errors recorded per endpoint to avoid excessive memory usage.
# _record_error stops appending to an endpoint's list once it holds this many entries.
_MAX_ERRORS_PER_ENDPOINT = 3

def _truncate_message(message: str, max_len: int = 500) -> str:
    """Truncates a message to a maximum length."""
    return message if len(message) <= max_len else message[:max_len-3] + "..."

def _classify_exception_code(exc: Exception) -> str:
    import requests
    try:
        from botocore.exceptions import ClientError
    except Exception:
        ClientError = tuple()

    msg = str(exc) if exc else ""

    # Request-level errors
    if isinstance(exc, (requests.exceptions.ReadTimeout, requests.exceptions.Timeout)):
        return "TIMEOUT"
    if isinstance(exc, requests.exceptions.ConnectionError):
        return "CONNECTION_ERROR"
    if isinstance(exc, requests.exceptions.HTTPError) and getattr(exc, "response", None):
        try:
            return f"HTTP_{exc.response.status_code}"
        except Exception:
            return "HTTP_ERROR"

    # AWS client errors
    if ClientError and isinstance(exc, ClientError):
        code = getattr(getattr(exc, "response", {}), "get", lambda *_: {})("Error", {}).get("Code")
        if code:
            return f"AWS_{code}"
        return "AWS_CLIENT_ERROR"

    # Persistent UI initialization errors
    if isinstance(exc, ValueError):
        if "Persistent App UI did not become ready" in msg:
            return "PERSISTENT_UI_NOT_READY"
        if "No presigned URL" in msg:
            return "PERSISTENT_UI_NO_PRESIGNED_URL"
        if "No persistent UI ID" in msg:
            return "PERSISTENT_UI_NO_ID"

    return "UNKNOWN_ERROR"

def _compose_failed_status(endpoint_errors: Dict[str, List[Dict[str, Any]]], endpoint_attempted: Dict[str, bool], fallback_message: Optional[str] = None) -> str:
    codes: List[str] = []
    for errs in (endpoint_errors or {}).values():
        for err in (errs or []):
            code = err.get("code")
            if code and code not in codes:
                codes.append(code)
    if codes:
        return f"FAILED({','.join(codes[:MAX_ENDPOINT_FAILURES])})"
    if any(endpoint_attempted.values()):
        if fallback_message:
            safe = _truncate_message(fallback_message).replace('(', '[').replace(')', ']')
            return f"FAILED({safe})"
    
        return "FAILED(UNKNOWN)"
    return "FAILED(NO_DATA)"

def _record_error(endpoint: str, fn: Any, exc: Exception, endpoint_errors: Dict[str, List]):
    """
    Append a structured description of a failed call to ``endpoint_errors[endpoint]``.

    The stored entry contains a short ``code``, the exception type name, the
    truncated message, the name of the callable that failed (``api``), the
    endpoint key and an ISO-8601 timestamp. At most ``_MAX_ERRORS_PER_ENDPOINT``
    entries are kept per endpoint; further errors are dropped.

    This function is called from ``except`` handlers, so it is written not to
    raise on malformed exceptions (for example an ``HTTPError`` without a response).

    :param endpoint: Endpoint key the error belongs to (e.g. ``'jobs'``).
    :param fn: The callable that raised; only its ``__name__`` is used.
    :param exc: The exception that was raised.
    :param endpoint_errors: Mapping of endpoint key to error list; mutated in place.
    """
    # Function-scope imports: the notebook's first cell does `import datetime`, which
    # binds the *module*, so a bare `datetime.now()` depends on a later cell rebinding
    # the name. Importing here makes the helper independent of cell execution order.
    from datetime import datetime as _datetime
    import requests

    now_iso = _datetime.now().isoformat()

    # Determine a short code for the error type
    if isinstance(exc, (requests.exceptions.ReadTimeout, requests.exceptions.Timeout)):
        code = "TIMEOUT"
    elif isinstance(exc, requests.exceptions.ConnectionError):
        code = "CONNECTION_ERROR"
    elif isinstance(exc, requests.exceptions.HTTPError):
        # Defensive: an HTTPError may carry no response; never let that mask the
        # original error by raising from inside the error recorder.
        status_code = getattr(getattr(exc, "response", None), "status_code", None)
        code = f"HTTP_{status_code}" if status_code is not None else "HTTP_ERROR"
    else:
        code = "UNKNOWN_ERROR"

    error_obj = {
        'code': code,
        'exception_type': type(exc).__name__,
        'message': _truncate_message(str(exc)),
        'api': getattr(fn, '__name__', 'unknown'),
        'endpoint_key': endpoint,
        'timestamp': now_iso
    }

    # Append with cap
    lst = endpoint_errors.setdefault(endpoint, [])
    if len(lst) < _MAX_ERRORS_PER_ENDPOINT:
        lst.append(error_obj)

def process_single_application(
    app_id: str,
    attempt_id: str,
    shs_client: Any,
    analyzer: Any,
    cluster_id: str,
    cluster_name: str,
    endpoint_errors: Dict[str, List[Dict[str, Any]]],
    endpoint_attempted: Dict[str, bool],
    endpoint_skipped_reasons: Dict[str, List[str]], # Kept for consistent function signature
    app_details: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Collect and analyze all data for one attempt of a Spark application.

    Work is done in four stages:

    1. Analyze application-level performance from the pre-fetched ``app_details``.
    2. Fetch jobs, stages, executors and SQL queries through a thread pool.
    3. Run the matching analyzer over each payload that came back.
    4. For every stage, fetch its tasks (and its task summary when the global
       ``TASK_BINARY`` equals ``'true'``) and analyze them.

    Both thread pools are created with ``max_workers=1``, so the submitted calls
    currently execute one at a time.

    Failures never propagate: API failures and post-processing failures are logged
    and recorded in ``endpoint_errors``; the remaining stages still run.

    :param app_id: Spark application id.
    :param attempt_id: Attempt id forwarded to every client call.
    :param shs_client: Client exposing the ``get_application_*`` / ``get_stage_*`` calls.
    :param analyzer: Object exposing the ``analyze_*`` methods.
    :param cluster_id: Stamped onto every produced record.
    :param cluster_name: Stamped onto every produced record.
    :param endpoint_errors: Per-endpoint error lists; mutated in place.
    :param endpoint_attempted: Per-endpoint "was called" flags; mutated in place.
    :param endpoint_skipped_reasons: Unused here; kept for a consistent signature.
    :param app_details: Application details already fetched by the caller.
    :returns: Dict with the keys ``applications``, ``jobs``, ``stages``, ``tasks``,
        ``sql_queries``, ``executors`` and ``task_summaries``, each a list of records.
    """
    logger.info(" -> Concurrently analyzing application %s, attempt %s", app_id, attempt_id)
    app_results = {
        'applications': [], 'jobs': [], 'stages': [], 'tasks': [],
        'sql_queries': [], 'executors': [], 'task_summaries': []
    }

    def add_and_append(data_list: List[Dict], result_key: str):
        """Stamp identity fields on each dict record and store it under ``result_key``."""
        if not isinstance(data_list, list):
            logger.warning("⚠️ Data provided to add_and_append for key '%s' is not a list, skipping.", result_key)
            return
        # Filter BEFORE stamping: assigning keys on a non-dict item would raise and
        # cause the whole batch to be discarded by the caller's except handler.
        records = [item for item in data_list if isinstance(item, dict)]
        for item in records:
            item['application_id'] = app_id
            item['attempt_id'] = attempt_id
            item['cluster_id'] = cluster_id
            item['cluster_name'] = cluster_name
        app_results[result_key].extend(records)

    def safe_call(endpoint_key: str, fn, *args, **kwargs):
        """Call ``fn``; on failure log it, record it and return ``None``."""
        endpoint_attempted[endpoint_key] = True
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.error("❌ Final failure for endpoint '%s' (attempt %s) after all retries: %s", endpoint_key, attempt_id, str(e), exc_info=False)
            _record_error(endpoint_key, fn, e, endpoint_errors)
            return None

    # --- Stage 1: Analyze Application Performance (using pre-fetched details) ---
    try:
        perf_analysis = analyzer.analyze_application_performance(app_details, attempt_id)
        add_and_append([perf_analysis], 'applications')
    except Exception as e:
        logger.error("❌ Failed during 'applications' post-processing for attempt %s: %s", attempt_id, str(e), exc_info=True)
        _record_error('applications', analyzer.analyze_application_performance, e, endpoint_errors)

    # --- Stage 2: Fetch Primary Data Endpoints through a thread pool ---
    fetched_data = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="SHS_Primary") as executor:
        future_to_endpoint = {
            executor.submit(safe_call, 'jobs', shs_client.get_application_jobs, app_id, attempt_id, status='succeeded'): 'jobs',
            executor.submit(safe_call, 'stages', shs_client.get_application_stages, app_id, attempt_id, status='complete'): 'stages',
            executor.submit(safe_call, 'executors', shs_client.get_application_executors, app_id, attempt_id): 'executors',
            executor.submit(safe_call, 'sql', shs_client.get_application_sql_queries, app_id, attempt_id): 'sql'
        }
        for future in concurrent.futures.as_completed(future_to_endpoint):
            endpoint_key = future_to_endpoint[future]
            try:
                result = future.result()
                # Empty/None payloads are treated as "no data" and not stored.
                if result:
                    fetched_data[endpoint_key] = result
            except Exception as e:
                logger.error("❌ Exception retrieving result for endpoint '%s': %s", endpoint_key, str(e), exc_info=True)

    # --- Stage 3: Process Fetched Data ---
    if 'jobs' in fetched_data:
        try:
            job_analysis = analyzer.analyze_job_performance(fetched_data['jobs'])
            add_and_append(job_analysis, 'jobs')
        except Exception as e:
            logger.error("❌ Failed during 'jobs' processing for attempt %s: %s", attempt_id, str(e), exc_info=True)
            _record_error('jobs', analyzer.analyze_job_performance, e, endpoint_errors)

    if 'executors' in fetched_data:
        try:
            executor_analysis = analyzer.analyze_executor_performance(fetched_data['executors'])
            add_and_append(executor_analysis, 'executors')
        except Exception as e:
            logger.error("❌ Failed during 'executors' processing for attempt %s: %s", attempt_id, str(e), exc_info=True)
            _record_error('executors', analyzer.analyze_executor_performance, e, endpoint_errors)

    if 'sql' in fetched_data:
        try:
            sql_analysis = analyzer.analyze_sql_queries(fetched_data['sql'])
            add_and_append(sql_analysis, 'sql_queries')
        except Exception as e:
            logger.error("❌ Failed during 'sql' processing for attempt %s: %s", attempt_id, str(e), exc_info=True)
            _record_error('sql', analyzer.analyze_sql_queries, e, endpoint_errors)

    # --- Stage 4: Process Stages and Fetch Sub-tasks ---
    if 'stages' in fetched_data:
        stages_data = fetched_data['stages']
        try:
            stages_analysis = analyzer.analyze_stage_performance(stages_data)
            add_and_append(stages_analysis, 'stages')
        except Exception as e:
            logger.error("❌ Failed during 'stages' post-processing for attempt %s: %s", attempt_id, str(e), exc_info=True)
            _record_error('stages', analyzer.analyze_stage_performance, e, endpoint_errors)

        # Fetch tasks and summaries for all stages (still attempted if stage analysis failed)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="SHS_SubTask") as executor:
            future_to_stage_task = {}
            for stage_raw in stages_data:
                stage_id = stage_raw.get('stageId')
                stage_attempt_id = stage_raw.get('attemptId', 0)
                if stage_id is None:
                    continue
                # Submit tasks and task_summaries calls for each stage
                future_to_stage_task[executor.submit(safe_call, 'tasks', shs_client.get_stage_tasks, app_id, attempt_id, stage_id, stage_attempt_id)] = ('tasks', stage_id, stage_attempt_id)
                if TASK_BINARY == 'true':
                    future_to_stage_task[executor.submit(safe_call, 'task_summaries', shs_client.get_stage_task_summary, app_id, attempt_id, stage_id, stage_attempt_id)] = ('task_summaries', stage_id, stage_attempt_id)

            for future in concurrent.futures.as_completed(future_to_stage_task):
                endpoint_key, stage_id, stage_attempt_id = future_to_stage_task[future]
                try:
                    result = future.result()
                    if not result:
                        continue

                    if endpoint_key == 'tasks':
                        # Tag each task with its stage so the analyzer can carry it through.
                        for task in result:
                            task['stage_id'] = stage_id
                            task['stage_attempt_id'] = stage_attempt_id
                        task_analysis = analyzer.analyze_task_performance(result)
                        add_and_append(task_analysis, 'tasks')
                    elif endpoint_key == 'task_summaries':
                        result['stage_id'] = stage_id
                        result['stage_attempt_id'] = stage_attempt_id
                        summary_analysis = analyzer.analyze_task_summaries([result])
                        add_and_append(summary_analysis, 'task_summaries')

                except Exception as e:
                    logger.error("❌ Exception retrieving sub-task result for endpoint '%s' stage '%s': %s", endpoint_key, stage_id, str(e), exc_info=True)
                    # safe_call swallows API exceptions, so what lands here is a
                    # post-processing failure; record it like the earlier stages do.
                    post_processor = analyzer.analyze_task_performance if endpoint_key == 'tasks' else analyzer.analyze_task_summaries
                    _record_error(endpoint_key, post_processor, e, endpoint_errors)
    else:
        logger.warning("⚠️ No stage data returned from API for attempt %s. Skipping task and summary collection.", attempt_id)

    return app_results

def analyze_application_attempts(
    app_id: str,
    shs_client: Any,
    analyzer: Any,
    cluster_id: str,
    cluster_name: str,
    endpoint_errors: Dict[str, List[Dict[str, Any]]],
    endpoint_attempted: Dict[str, bool],
    endpoint_skipped_reasons: Dict[str, List[str]]
) -> List[Dict[str, Any]]:
    """
    Fetch an application's details and process each of its attempts individually.

    Every attempt listed in the application details is passed to
    ``process_single_application``; attempts are NOT filtered on their
    ``completed`` flag.

    :param app_id: Spark application id.
    :param shs_client: Client exposing ``get_application_details`` (and the calls
        used by ``process_single_application``).
    :param analyzer: Analyzer forwarded to ``process_single_application``.
    :param cluster_id: Forwarded to ``process_single_application``.
    :param cluster_name: Forwarded to ``process_single_application``.
    :param endpoint_errors: Per-endpoint error lists; mutated in place.
    :param endpoint_attempted: Per-endpoint "was called" flags; mutated in place.
    :param endpoint_skipped_reasons: Forwarded unchanged.
    :returns: One result dict per processed attempt; empty when the details call
        failed or the application lists no attempts.
    """
    all_attempt_results = []

    # This is a safecall wrapper for use inside this orchestrator function
    def orchestrator_safe_call(endpoint_key: str, fn, *args, **kwargs):
        endpoint_attempted[endpoint_key] = True
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.error("API call for endpoint '%s' app '%s' failed: %s", endpoint_key, app_id, str(e), exc_info=True)
            _record_error(endpoint_key, fn, e, endpoint_errors)
            return None

    app_details = orchestrator_safe_call('applications', shs_client.get_application_details, app_id)

    if not app_details or not app_details.get('attempts'):
        logger.warning("Application %s has no details or attempts data. Skipping.", app_id)
        return []

    attempts = app_details.get('attempts', [])

    # NOTE(review): the message says "successful attempts", but all attempts are
    # analyzed, so both counts are the same number. Confirm whether filtering on
    # attempt['completed'] should be reinstated.
    logger.info("Found %d total attempts for app %s. Will analyze %d successful attempts.", len(attempts), app_id, len(attempts))

    for attempt in attempts:
        # NOTE(review): 42 is a placeholder used when the attempt has no 'attemptId';
        # it keeps such attempts from being skipped by the falsy check below.
        # Presumably the client tolerates this value — verify against SparkHistoryServerClient.
        attempt_id = attempt.get('attemptId', 42)
        if not attempt_id:
            continue

        attempt_result = process_single_application(
            app_id=app_id,
            attempt_id=attempt_id,
            shs_client=shs_client,
            analyzer=analyzer,
            cluster_id=cluster_id,
            cluster_name=cluster_name,
            endpoint_errors=endpoint_errors,
            endpoint_attempted=endpoint_attempted,
            endpoint_skipped_reasons=endpoint_skipped_reasons,
            app_details=app_details
        )
        if attempt_result:
            all_attempt_results.append(attempt_result)

    return all_attempt_results

def analyze_single_cluster(
    cluster_info: Dict,
    timeout_seconds: int,
    max_applications: int,
    spark_session: SparkSession
) -> Dict[str, Any]:
    """
    Analyze one Databricks cluster through its Spark UI REST endpoints.

    Builds the Spark UI base URL from the cluster id and Spark context id, opens a
    cookie-authenticated ``requests`` session, lists up to ``max_applications``
    applications and collects the data of every attempt of each application via
    ``analyze_application_attempts``.

    No exception escapes: any failure is logged and reflected in
    ``analysis_status`` / ``error_message`` of the returned dict.

    :param cluster_info: Must contain ``cluster_id``, ``cluster_name`` and
        ``spark_context_id``; ``status`` and ``normalized_instance_hours`` are optional.
    :param timeout_seconds: Timeout (seconds) for the session-establishing request.
    :param max_applications: Passed as ``limit`` to ``get_applications``.
    :param spark_session: Handed to ``SparkMetricsAnalyzer``.
    :returns: Dict with cluster identity, the seven record lists, ``analysis_status``
        (``COMPLETED``, ``NO_APPLICATIONS(No data)``, ``SKIPPED(NO_SPARK_CONTEXT)`` or
        ``FAILED(...)``), ``error_message`` and the three endpoint bookkeeping dicts.
    """
    cluster_id = cluster_info['cluster_id']
    cluster_name = cluster_info['cluster_name']
    spark_context_id = cluster_info['spark_context_id']
    logger.info("Starting analysis for cluster: %s (%s)", cluster_name, cluster_id)

    data_keys = ['applications', 'jobs', 'stages', 'tasks', 'sql_queries', 'executors', 'task_summaries']
    results_aggregator = {
        'cluster_id': cluster_id, 'cluster_name': cluster_name, 'spark_context_id': spark_context_id, 'status': cluster_info.get('status', 'UNKNOWN'),
        'normalized_instance_hours': cluster_info.get('normalized_instance_hours', 0),
        'applications': [], 'jobs': [], 'stages': [], 'tasks': [], 'sql_queries': [], 'executors': [], 'task_summaries': [],
        'analysis_status': 'PENDING', 'error_message': ''
    }

    tracked_endpoints = ['applications', 'jobs', 'stages', 'tasks', 'sql', 'executors', 'task_summaries']
    endpoint_errors: Dict[str, List[Dict[str, Any]]] = {k: [] for k in tracked_endpoints}
    endpoint_attempted: Dict[str, bool] = {k: False for k in tracked_endpoints}
    endpoint_skipped_reasons: Dict[str, List[str]] = {k: [] for k in tracked_endpoints}

    # The bookkeeping dicts are only ever mutated in place, so attaching them once
    # up front makes every return path carry them.
    results_aggregator['endpoint_errors'] = endpoint_errors
    results_aggregator['endpoint_attempted'] = endpoint_attempted
    results_aggregator['endpoint_skipped_reasons'] = endpoint_skipped_reasons

    # Bound before the closure below is defined so it can always be referenced safely.
    shs_client = None
    session = None

    # This is a safecall wrapper for use inside this orchestrator function
    def cluster_safe_call(endpoint_key: str, fn, *args, **kwargs):
        # Circuit breaker logic (if implemented in client)
        if hasattr(shs_client, 'should_skip') and shs_client.should_skip(endpoint_key):
            logger.warning("Skipping endpoint %s due to prior failures", endpoint_key)
            endpoint_skipped_reasons.setdefault(endpoint_key, []).append("Skipped due to prior failures (circuit breaker)")
            return None

        endpoint_attempted[endpoint_key] = True
        try:
            result = fn(*args, **kwargs)
            try:
                if hasattr(shs_client, 'record_success'):
                    shs_client.record_success(endpoint_key)
            except Exception as bookkeeping_error:
                # Best effort: circuit-breaker bookkeeping must not fail a successful call.
                logger.debug("record_success failed for endpoint '%s': %s", endpoint_key, bookkeeping_error)
            return result
        except Exception as e:
            logger.error("Endpoint '%s' call failed: %s", endpoint_key, str(e), exc_info=True)
            _record_error(endpoint_key, fn, e, endpoint_errors)
            try:
                if hasattr(shs_client, 'record_failure'):
                    shs_client.record_failure(endpoint_key)
            except Exception as bookkeeping_error:
                logger.debug("record_failure failed for endpoint '%s': %s", endpoint_key, bookkeeping_error)
            return None

    try:
        if spark_context_id is None:
            # Without a Spark context id the Spark UI URL cannot be built; report this
            # explicitly instead of leaving the cluster in the initial PENDING state.
            logger.warning("Cluster %s (%s) has no spark_context_id; skipping Spark UI analysis.", cluster_name, cluster_id)
            results_aggregator['analysis_status'] = 'SKIPPED(NO_SPARK_CONTEXT)'
            results_aggregator['error_message'] = "Cluster has no spark_context_id; Spark UI URL cannot be built"
        else:
            dataplane_workspace_url = dbutils.secrets.get(scope="shscreds", key="dpurl")
            base_url = f'{dataplane_workspace_url}/sparkui/{cluster_id}/driver-{spark_context_id}/'

            cookies = {'DATAPLANE_DOMAIN_DBAUTH': dbutils.secrets.get(scope="shscreds", key="cookies"), 'PRIMARY_DOMAIN':'workspace'}

            try:
                session = requests.Session()
                # A timeout is required here: requests waits indefinitely by default,
                # which would block this cluster's worker thread forever.
                response = session.get(base_url, cookies=cookies, allow_redirects=True, timeout=timeout_seconds)
                response.raise_for_status()
                logger.info("✅ HTTP session established successfully (Status: %s)", response.status_code)
            except requests.exceptions.RequestException as e:
                logger.error("❌ Failed to establish HTTP session: %s", str(e), exc_info=True)
                raise

            shs_client = SparkHistoryServerClient(base_url, session, cookies)
            analyzer = SparkMetricsAnalyzer(spark_session)

            applications = cluster_safe_call('applications', shs_client.get_applications, limit=max_applications)

            if not applications:
                logger.warning("No applications found in Spark History Server for %s", cluster_name)
                results_aggregator['analysis_status'] = 'NO_APPLICATIONS(No data)'
                return results_aggregator

            logger.info("Found %s applications to analyze in %s", len(applications), cluster_name)
            for app in applications:
                app_id = app.get('id')
                if not app_id:
                    continue

                attempt_results_list = analyze_application_attempts(
                    app_id, shs_client, analyzer, cluster_id, cluster_name,
                    endpoint_errors, endpoint_attempted, endpoint_skipped_reasons
                )

                for app_data in attempt_results_list:
                    if app_data:
                        for key in data_keys:
                            results_aggregator[key].extend(app_data.get(key, []))

            # Finalize status
            if not any(results_aggregator.get(key) for key in data_keys):
                results_aggregator['analysis_status'] = _compose_failed_status(endpoint_errors, endpoint_attempted, results_aggregator.get('error_message'))
                if not results_aggregator['error_message']:
                    results_aggregator['error_message'] = "No data returned from SHS endpoints for any successful attempts"
                logger.warning("Cluster %s analysis produced no data across all endpoints.", cluster_name)
            else:
                results_aggregator['analysis_status'] = 'COMPLETED'

            logger.info("Completed analysis for cluster: %s", cluster_name)
    except Exception as e:
        logger.error("Failed to analyze cluster %s: %s", cluster_name, str(e), exc_info=True)
        results_aggregator['analysis_status'] = f"FAILED({_classify_exception_code(e)})"
        results_aggregator['error_message'] = str(e)
    finally:
        # The client is not used after this function returns, so release the
        # session's pooled connections here (also on the early return above).
        if session is not None:
            session.close()

    return results_aggregator

def process_clusters_in_batches(
    clusters_to_analyze: List[Dict],
    batch_size: int,
    batch_delay_seconds: int,
    spark_session: Any
) -> Tuple[List[Dict], List[Dict], List[Dict], List[Dict], List[Dict], List[Dict], List[Dict], List[Dict]]:
    """
    Analyze clusters batch by batch and merge the per-cluster results.

    The input is sliced into consecutive groups of ``batch_size`` clusters. Each
    cluster of a group is submitted to ``analyze_single_cluster`` on its own worker
    thread (using the global ``TIMEOUT_SECONDS`` and ``MAX_APPLICATIONS``), and the
    function sleeps ``batch_delay_seconds`` between groups, but not after the last.

    For every cluster one summary row is produced. Its ``status_details`` field is a
    JSON string mapping each endpoint to ``OK``, ``FAILED(<codes>)`` or
    ``SKIPPED(<reason>)``. If collecting a cluster's result raises, a
    ``FAILED_PROCESSING`` summary row with zero counts is emitted instead.

    :returns: ``(applications, jobs, stages, tasks, sql_queries, executors,
        task_summaries, cluster_summaries)``.
    """
    collected = {'applications': [], 'jobs': [], 'stages': [], 'tasks': [], 'sql_queries': [], 'executors': [], 'task_summaries': []}
    cluster_summaries = []

    def endpoint_status_summary(c: Dict[str, Any]) -> str:
        """Serialize a per-endpoint OK/FAILED/SKIPPED verdict as compact JSON."""
        # endpoint name -> key under which that endpoint's records are stored
        endpoint_to_data_key = {'applications': 'applications', 'jobs': 'jobs', 'stages': 'stages', 'tasks': 'tasks', 'sql': 'sql_queries', 'executors': 'executors', 'task_summaries': 'task_summaries'}
        attempted_flags = c.get('endpoint_attempted', {})
        recorded_errors = c.get('endpoint_errors', {})
        skip_reasons = c.get('endpoint_skipped_reasons', {})

        def verdict_for(ep: str, data_key: str) -> str:
            if len(c.get(data_key, [])) > 0:
                return 'OK'

            ep_errors = recorded_errors.get(ep, [])
            if bool(attempted_flags.get(ep, False)) and len(ep_errors) > 0:
                # Only the error codes are reported, not the full error objects.
                error_codes = [err.get('code', 'UNKNOWN') for err in ep_errors[:MAX_ENDPOINT_FAILURES]]
                return f"FAILED({','.join(error_codes)})"

            ep_skips = skip_reasons.get(ep, [])
            reason = ep_skips[0] if ep_skips else "No data"
            # Keep the reason short so the JSON stays compact.
            if len(reason) > 30:
                reason = reason[:27] + "..."
            return f"SKIPPED({reason})"

        status_obj = {ep: verdict_for(ep, data_key) for ep, data_key in endpoint_to_data_key.items()}
        return json.dumps(status_obj, separators=(',', ':'))

    def summarize_success(res: Dict[str, Any]) -> Dict[str, Any]:
        """Summary row for a cluster whose analysis returned a result dict."""
        return {
            'cluster_id': res['cluster_id'],
            'cluster_name': res['cluster_name'],
            'status': res['status'],
            'spark_context_id': res['spark_context_id'],
            'analysis_status': res['analysis_status'],
            'status_details': endpoint_status_summary(res),
            'normalized_instance_hours': res.get('normalized_instance_hours', 0),
            'total_applications': len(res['applications']),
            'total_jobs': len(res['jobs']),
            'total_stages': len(res['stages']),
            'total_tasks': len(res['tasks']),
            'total_sql_queries': len(res['sql_queries']),
            'total_executors': len(res['executors']),
            'total_task_summaries': len(res['task_summaries']),
        }

    def summarize_failure(info: Dict[str, Any], error: Exception) -> Dict[str, Any]:
        """Summary row for a cluster whose result could not be collected."""
        return {
            'cluster_id': info['cluster_id'],
            'cluster_name': info['cluster_name'],
            'spark_context_id': info['spark_context_id'],
            'status': 'FAILED_PROCESSING',
            'analysis_status': f"FAILED({_classify_exception_code(error)})",
            'status_details': json.dumps({'error': _truncate_message(str(error))}),
            'normalized_instance_hours': info.get('normalized_instance_hours', 0),
            'total_applications': 0, 'total_jobs': 0, 'total_stages': 0,
            'total_tasks': 0, 'total_sql_queries': 0, 'total_executors': 0, 'total_task_summaries': 0
        }

    cluster_count = len(clusters_to_analyze)
    total_batches = (cluster_count + batch_size - 1) // batch_size
    for current_batch_num, start in enumerate(range(0, cluster_count, batch_size), start=1):
        batch_clusters = clusters_to_analyze[start:start + batch_size]
        logger.info("Processing batch %d/%d...", current_batch_num, total_batches)

        # One worker per cluster in the batch.
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(batch_clusters)) as executor:
            future_to_cluster = {}
            for c_info in batch_clusters:
                submitted = executor.submit(analyze_single_cluster, c_info, TIMEOUT_SECONDS, MAX_APPLICATIONS, spark_session)
                future_to_cluster[submitted] = c_info

            for finished in concurrent.futures.as_completed(future_to_cluster):
                cluster_info = future_to_cluster[finished]
                try:
                    cluster_results = finished.result()
                    if not cluster_results:
                        continue
                    for key in collected.keys():
                        collected[key].extend(cluster_results.get(key, []))
                    cluster_summaries.append(summarize_success(cluster_results))
                except Exception as e:
                    logger.error("Error processing results for cluster %s: %s", cluster_info['cluster_id'], str(e), exc_info=True)
                    cluster_summaries.append(summarize_failure(cluster_info, e))

        if current_batch_num < total_batches:
            logger.info("Waiting %d seconds between batches...", batch_delay_seconds)
            time.sleep(batch_delay_seconds)

    return (
        collected['applications'],
        collected['jobs'],
        collected['stages'],
        collected['tasks'],
        collected['sql_queries'],
        collected['executors'],
        collected['task_summaries'],
        cluster_summaries,
    )


### Main Execution

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from typing import Dict, Any
import time
import logging

# Logger named after the current module; main_analysis and the top-level code in this cell log through it.
logger = logging.getLogger(__name__)

def main_analysis():
    """Run the end-to-end Databricks cluster analysis and return its results.

    The pipeline discovers clusters with ``DBXClusterDiscovery``, collects
    Spark History data through ``process_clusters_in_batches``, converts the
    collected records into DataFrames with ``SparkMetricsAnalyzer`` and
    finally derives aggregate counts.

    Returns:
        dict: the keys ``cluster_summaries_df``, ``applications_df``,
        ``jobs_df``, ``stages_df``, ``tasks_df``, ``sql_df``,
        ``executors_df`` and ``task_summaries_df`` hold the DataFrames
        (``cluster_summaries_df`` is ``None`` when no per-cluster summaries
        were produced); ``summary`` holds a dict of aggregate counts.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    # Per-cluster counters, in the order they appear in the summary schema and
    # in the order they are checked when deciding whether a cluster is
    # "fully analyzed".
    count_fields = (
        'total_applications',
        'total_jobs',
        'total_stages',
        'total_tasks',
        'total_sql_queries',
        'total_executors',
        'total_task_summaries',
    )

    try:
        # --- Discover clusters -------------------------------------------
        discovery = DBXClusterDiscovery()
        logger.info("Discovering DBX clusters based on specified criteria...")
        clusters = discovery.discover_clusters(max_clusters=MAX_CLUSTERS)
        discovered_count = len(clusters)
        logger.info("Will analyze %s clusters.", discovered_count)

        # --- Collect raw records, batch by batch -------------------------
        batch_output = process_clusters_in_batches(
            clusters, BATCH_SIZE, BATCH_DELAY_SECONDS, spark
        )
        (all_apps, all_jobs, all_stages, all_tasks,
         all_sql, all_execs, all_task_sums, summaries) = batch_output

        # --- Build the analysis DataFrames -------------------------------
        logger.info("Creating analysis DataFrames...")
        frames = SparkMetricsAnalyzer(spark).create_dynamic_dataframes(
            all_apps, all_jobs, all_stages, all_tasks, all_sql, all_execs, all_task_sums
        )
        apps_df, jobs_df, stages_df, tasks_df, sql_df, exec_df, task_sum_df = frames

        cluster_summary_df = None
        if summaries:
            string_columns = (
                "cluster_id", "cluster_name", "spark_context_id",
                "status", "analysis_status", "status_details",
            )
            # NOTE(review): normalized_instance_hours is typed as an integer
            # here; confirm upstream never supplies a fractional value.
            integer_columns = ("normalized_instance_hours",) + count_fields
            summary_schema = StructType(
                [StructField(name, StringType(), True) for name in string_columns]
                + [StructField(name, IntegerType(), True) for name in integer_columns]
            )
            cluster_summary_df = spark.createDataFrame(summaries, schema=summary_schema)

        # --- Aggregate counts --------------------------------------------
        # len() of a list is used on purpose: the builtin sum() is shadowed in
        # this cell by `from pyspark.sql.functions import col, sum`.
        completed = [
            entry for entry in (summaries or [])
            if entry.get('analysis_status') == 'COMPLETED'
        ]
        total_clusters_analyzed = len(completed)

        # A cluster counts as fully analyzed when every counter is non-zero.
        # NOTE(review): this scans all summaries while the rate below divides
        # by COMPLETED ones only; confirm a non-COMPLETED summary can never
        # carry non-zero counters, otherwise the rate could exceed 100.
        clusters_fully_analyzed = 0
        for entry in (summaries or []):
            if all(entry.get(field, 0) > 0 for field in count_fields):
                clusters_fully_analyzed += 1

        if total_clusters_analyzed > 0:
            cluster_analysis_success_rate = round(
                (clusters_fully_analyzed / total_clusters_analyzed) * 100.0, 2
            )
        else:
            cluster_analysis_success_rate = 0.0

        final_summary = {
            'clusters_discovered_count': discovered_count,
            'clusters_extracted_count': total_clusters_analyzed,
            'clusters_fully_analyzed_count': clusters_fully_analyzed,
            'cluster_analysis_success_rate %': cluster_analysis_success_rate,
            'total_applications': len(all_apps),
            'total_jobs': len(all_jobs),
            'total_stages': len(all_stages),
            'total_tasks': len(all_tasks),
            'total_sql_queries': len(all_sql),
            'total_executors': len(all_execs),
            'total_task_summaries': len(all_task_sums)
        }

        return {
            "cluster_summaries_df": cluster_summary_df,
            "applications_df": apps_df,
            "jobs_df": jobs_df,
            "stages_df": stages_df,
            "tasks_df": tasks_df,
            "sql_df": sql_df,
            "executors_df": exec_df,
            "task_summaries_df": task_sum_df,
            "summary": final_summary
        }
    except Exception as e:
        logger.error("Main analysis failed: %s", str(e), exc_info=True)
        raise

# Execute main analysis
results = main_analysis()

# Display summary and make DataFrames available
if results:
    # Expose each result under a notebook-level name so the exploration cells
    # further down can display it.
    cluster_summaries_df = results.get('cluster_summaries_df')
    applications_df = results.get('applications_df')
    jobs_df = results.get('jobs_df')
    stages_df = results.get('stages_df')
    tasks_df = results.get('tasks_df')
    sql_df = results.get('sql_df')
    executors_df = results.get('executors_df')
    task_summaries_df = results.get('task_summaries_df')
    analysis_summary = results.get('summary', {})

    print("-" * 100)
    # This notebook analyzes Databricks clusters; the banner previously said "EMR".
    print("DATABRICKS SPARK HISTORY ANALYSIS COMPLETED!")
    print("-" * 100)
    for key, value in analysis_summary.items():
        print(f"  {key.replace('_', ' ').title()}: {value}")
    print("\nDataFrames available for analysis: cluster_summaries_df, applications_df, jobs_df, stages_df, tasks_df, sql_df, executors_df, task_summaries_df")

    # Write outputs as managed Delta tables in the configured catalog/schema
    if ENVIRONMENT == "prod":
        logger.info("Writing analysis results to Delta tables in %s.%s", CATALOG_NAME, SCHEMA_NAME)
        # Remembered so a failure can be attributed to a specific table.
        full_table_name = None
        try:
            # One UTC timestamp per run: every table written by this run shares the suffix.
            run_ts = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
            for df_name, df_instance in results.items():
                # Only the DataFrame entries are persisted; 'summary' is skipped by the
                # suffix test and DataFrames that were not produced are None.
                if not df_name.endswith('_df') or df_instance is None:
                    continue
                # Strip only the trailing '_df' (str.replace would also remove
                # any '_df' occurring in the middle of a name).
                table_name = df_name[:-len('_df')]
                versioned_table_name = f"{table_name}_{run_ts}"
                full_table_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{versioned_table_name}"
                # NOTE(review): overwriteSchema is documented for overwrite mode;
                # with errorifexists it looks like a no-op - confirm before removing.
                (
                    df_instance.write
                    .format("delta")
                    .mode("errorifexists")
                    .option("overwriteSchema", "true")
                    .saveAsTable(full_table_name)
                )
            logger.info("All analysis results successfully written to Delta tables.")
        except Exception as e:
            # Best-effort write: keep the notebook running, but record the traceback so
            # the "check the logs" hint below actually has something to point at.
            logger.error(
                "Failed to write analysis results to Delta tables (last table attempted: %s): %s",
                full_table_name, str(e), exc_info=True
            )
            print(f"\n❌ WRITING RESULTS TO DELTA TABLES FAILED: {str(e)} - check the logs above for detailed error information.")

# Flush and close every handler currently attached to the root logger.
# A copy of the handler list is iterated so the loop is unaffected if the
# list changes while handlers are being closed.
for root_handler in list(logging.root.handlers):
    root_handler.flush()
    root_handler.close()

### Table Exploration

In [0]:
# Render the final counts held in the analysis_summary dictionary as a one-row table.
if 'analysis_summary' not in locals() or not analysis_summary:
    print("Analysis summary is not available.")
else:
    display(spark.createDataFrame([analysis_summary]))

In [0]:
# Show the per-cluster summaries if the main-execution cell produced them.
if 'cluster_summaries_df' not in locals() or not cluster_summaries_df:
    print("Cluster summaries DataFrame is not available.")
else:
    display(cluster_summaries_df)

In [0]:
# Show the applications DataFrame if the main-execution cell produced it.
if 'applications_df' not in locals() or not applications_df:
    print("Applications DataFrame is not available.")
else:
    display(applications_df)

In [0]:
# Show the executors DataFrame if the main-execution cell produced it.
if 'executors_df' not in locals() or not executors_df:
    print("Executors DataFrame is not available.")
else:
    display(executors_df)

In [0]:
# Show the jobs DataFrame if the main-execution cell produced it.
if 'jobs_df' not in locals() or not jobs_df:
    print("Jobs DataFrame is not available.")
else:
    display(jobs_df)

In [0]:
# Show the stages DataFrame if the main-execution cell produced it.
if 'stages_df' not in locals() or not stages_df:
    print("Stages DataFrame is not available.")
else:
    display(stages_df)

In [0]:
# Show the tasks DataFrame if the main-execution cell produced it.
if 'tasks_df' not in locals() or not tasks_df:
    print("Tasks DataFrame is not available.")
else:
    display(tasks_df)

In [0]:
# Show the task summaries DataFrame if the main-execution cell produced it.
if 'task_summaries_df' not in locals() or not task_summaries_df:
    print("Task summaries DataFrame is not available.")
else:
    display(task_summaries_df)

In [0]:
# Show the SQL queries DataFrame if the main-execution cell produced it.
if 'sql_df' not in locals() or not sql_df:
    print("SQL queries DataFrame is not available.")
else:
    display(sql_df)