In [0]:
# Databricks / PySpark Diagnostic Notebook for Streaming Tuning
from pyspark.sql.streaming import StreamingQueryListener
import json
from datetime import datetime
from pyspark.sql.functions import current_timestamp
import time
import threading
import matplotlib.pyplot as plt

In [0]:
# 1. Show last progress snapshot for a given query
def show_query_progress(query):
    """Print the most recent progress snapshot of a streaming query as JSON.

    Args:
        query: Object exposing a ``lastProgress`` attribute (e.g. a
            ``StreamingQuery``). The snapshot must be JSON-serializable.

    Returns:
        None. All output goes to stdout.
    """
    progress = query.lastProgress
    # lastProgress can be empty/None when no progress has been reported yet;
    # say so explicitly instead of dumping a bare "null".
    if not progress:
        print("No progress yet. Query may not have started.")
        return
    print("Last Progress Snapshot:")
    print(json.dumps(progress, indent=2))

# 2. Display recommended tuning hints based on batch duration
def diagnose_performance(query, interval=60):
    """Compare the latest micro-batch duration with the trigger interval and print tuning hints.

    Args:
        query: Object exposing ``lastProgress``; the snapshot is expected to
            carry ``durationMs['triggerExecution']`` in milliseconds.
        interval: Trigger interval in seconds to compare against (default 60).
            Pass the value that matches the query's actual trigger.

    Returns:
        None. The diagnosis is printed to stdout; any failure while reading or
        comparing the metrics is reported as a single printed line.
    """
    snapshot = query.lastProgress
    if not snapshot:
        print("No progress yet. Query may not have started.")
        return

    try:
        batch_time = int(snapshot['durationMs']['triggerExecution']) / 1000
        trigger_interval = interval

        print(f"\nTrigger Interval: {trigger_interval} seconds")
        print(f"Batch Duration: {batch_time:.2f} seconds")

        # Select the advice for the matching regime first, then emit it line by line.
        if batch_time > trigger_interval:
            advice = (
                "\n[DIAGNOSIS] Batch is slower than trigger interval.",
                "- Tune transformations to reduce load.",
                "- Reduce input rate (e.g., maxOffsetsPerTrigger).",
                "- Consider scaling up cluster cores or executors.",
            )
        elif batch_time < trigger_interval * 0.5:
            advice = (
                "\n[DIAGNOSIS] Batch is much faster than trigger.",
                "- You may be underutilizing resources.",
                "- Consider reducing trigger interval for lower latency.",
            )
        else:
            advice = (
                "\n[DIAGNOSIS] Trigger and batch are balanced.",
                "- Good balance. Monitor input growth and state size.",
            )

        for line in advice:
            print(line)

    except Exception as err:
        print("Could not extract performance info:", str(err))

# 3. Optional: Attach a custom listener for logging
class QueryLogger(StreamingQueryListener):
    """Streaming listener that prints a one-line summary for every progress event.

    PySpark's ``StreamingQueryListener`` declares ``onQueryStarted`` and
    ``onQueryTerminated`` as abstract, so a subclass that only overrides
    ``onQueryProgress`` cannot be instantiated. The other callbacks are
    therefore implemented here as explicit no-ops.

    Register an instance with ``spark.streams.addListener(QueryLogger())``.
    """

    def onQueryStarted(self, event):
        """No-op; present only to satisfy the abstract listener interface."""
        pass

    def onQueryProgress(self, event):
        """Print batch id, timestamp, input row count and trigger duration (ms)."""
        progress = event.progress
        # The listener's progress object exposes its fields as attributes;
        # only ``durationMs`` is a plain dict keyed by phase name.
        print(
            f"[Batch {progress.batchId}] {progress.timestamp} | "
            f"Rows: {progress.numInputRows} | "
            f"Duration: {progress.durationMs['triggerExecution']}ms"
        )

    def onQueryIdle(self, event):
        """No-op; invoked by newer Spark versions when the query has no data to process."""
        pass

    def onQueryTerminated(self, event):
        """No-op; present only to satisfy the abstract listener interface."""
        pass

# 4. List all active streaming queries
def list_active_queries():
    """Print a numbered line (ID, name, status message) for each active streaming query.

    Reads the active queries from the notebook-global ``spark`` session.
    Prints a single notice when there are none.
    """
    active = spark.streams.active
    if active:
        for i, q in enumerate(active, start=1):
            print(f"[{i}] ID: {q.id}, Name: {q.name}, Status: {q.status['message']}")
        return
    print("No active streaming queries.")

# 5. Stop all active queries (use with caution)
def stop_all_queries():
    """Stop every active streaming query on the global ``spark`` session.

    Each query's ID and name are printed immediately before it is stopped.
    Use with caution: this affects all active queries, not just ones started
    from this notebook.
    """
    active = spark.streams.active
    for q in active:
        print(f"Stopping query ID: {q.id}, Name: {q.name}")
        q.stop()

# 6. Simulate staged Delta writes for testing streaming ingestion
def simulate_delta_commits(base_path="/tmp/delta/stream_source", num_batches=50, rows_per_batch=1000, interval_sec=10):
    """Append a series of small batches to a Delta path to emulate staged upstream commits.

    Args:
        base_path: Delta location that receives the appended rows.
        num_batches: Number of append commits to perform.
        rows_per_batch: Rows per commit; batch ``i`` covers the id range
            ``[i * rows_per_batch, (i + 1) * rows_per_batch)``.
        interval_sec: Seconds to sleep after every commit.

    Returns:
        None. Blocks the calling thread for the whole run.
    """
    for i in range(num_batches):
        id_range = spark.range(i * rows_per_batch, (i + 1) * rows_per_batch)
        # Stamp each row with the current timestamp in a "ts" column.
        stamped = id_range.withColumn("ts", current_timestamp())
        writer = stamped.write.format("delta").mode("append")
        writer.save(base_path)
        print(f"Committed {rows_per_batch} rows to Delta at batch {i}")
        time.sleep(interval_sec)

# 6b. Run simulate_delta_commits in a background daemon thread
def start_background_commits(base_path="/tmp/delta/stream_source", num_batches=50, rows_per_batch=1000, interval_sec=20):
    """Launch ``simulate_delta_commits`` on a daemon thread and return the thread.

    Args:
        base_path: Delta location passed through to the writer.
        num_batches: Number of commits the writer performs.
        rows_per_batch: Rows written per commit.
        interval_sec: Pause between commits in seconds (default 20 here).

    Returns:
        The started ``threading.Thread``. It is a daemon thread, so it does
        not keep the interpreter alive on its own.
    """
    commit_args = (base_path, num_batches, rows_per_batch, interval_sec)
    worker = threading.Thread(target=simulate_delta_commits, args=commit_args, daemon=True)
    worker.start()
    print("Simulated Delta commit thread started.")
    return worker

# 7. Attach a streaming query to the simulated Delta source
def attach_read_stream(base_path="/tmp/delta/stream_source", proc_time="5 seconds",
                       checkpoint_path="/tmp/delta/checkpoints/stream_test"):
    """Start a console-sink streaming query that reads from a Delta path.

    Args:
        base_path: Delta location to read as a stream.
        proc_time: Processing-time trigger interval, e.g. ``"5 seconds"``.
        checkpoint_path: Checkpoint directory for the query. Defaults to the
            location this helper has always used. Pass a distinct path when
            attaching to a different source, so the new query does not pick up
            checkpoint state recorded for another stream.

    Returns:
        Tuple ``(df, query)``: the streaming DataFrame and the started query.
    """
    df = spark.readStream.format("delta").load(base_path)
    query = (
        df.writeStream
          .format("console")
          .outputMode("append")
          .option("checkpointLocation", checkpoint_path)
          .trigger(processingTime=proc_time)
          .start()
    )
    print("Streaming query started.")
    return df, query

# 8. Check partition count relative to cluster cores for batch DataFrames only
def check_partitions_and_cores_static(df):
    """Report how a batch DataFrame's partition count compares with the cluster's parallelism.

    The core count is taken from ``spark.sparkContext.defaultParallelism``.
    Fewer partitions than cores is flagged as under-partitioned; more than
    four times the cores is flagged as over-partitioned.

    Args:
        df: DataFrame to inspect. Streaming DataFrames are rejected with a
            printed warning because their partitions cannot be inspected here.

    Returns:
        None. Findings are printed to stdout.
    """
    if df.isStreaming:
        print("⚠️ Cannot check partitions on a streaming DataFrame. Try this after writing to Delta or converting to batch.")
        return

    num_partitions = df.rdd.getNumPartitions()
    total_cores = spark.sparkContext.defaultParallelism
    print(f"DataFrame partitions: {num_partitions}")
    print(f"Available cluster cores: {total_cores}")

    # Choose the verdict first, then print it once.
    if num_partitions < total_cores:
        verdict = "⚠️ Under-partitioned: not all cores will be used."
    elif num_partitions > total_cores * 4:
        verdict = "⚠️ Over-partitioned: could lead to overhead and small files."
    else:
        verdict = "✅ Partitioning is well balanced with available cores."
    print(verdict)

# 9. Visualize micro-batch trigger-execution durations over time for a query

def plot_task_durations(query):
    """Plot the trigger-execution time (seconds) of a query's recent micro-batches.

    Args:
        query: Object exposing ``recentProgress``, a sequence of progress
            snapshots with ``durationMs['triggerExecution']`` (milliseconds)
            and ``timestamp`` entries.

    Returns:
        None. Shows a matplotlib line chart, or prints a notice when there is
        no history or no snapshot with both required fields.
    """
    progress_history = query.recentProgress
    if not progress_history:
        print("No progress history available.")
        return

    durations = []
    timestamps = []
    for p in progress_history:
        try:
            # Read both fields before appending anything, so a snapshot that
            # lacks either one is skipped as a whole and the two lists always
            # stay the same length (a mismatch would make the plot call fail).
            duration_sec = int(p['durationMs']['triggerExecution']) / 1000
            timestamp = p['timestamp']
        except (LookupError, TypeError, ValueError):
            # Malformed or incomplete snapshot: skip it, keep the rest.
            continue
        durations.append(duration_sec)
        timestamps.append(timestamp)

    if not durations:
        print("No valid durations to plot.")
        return

    plt.figure(figsize=(10, 4))
    plt.plot(timestamps, durations, marker='o', color='dodgerblue')
    plt.xticks(rotation=45)
    plt.ylabel("Trigger Duration (sec)")
    plt.title("Micro-Batch Execution Time over Time")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [0]:
# Attach the console-sink streaming query to the simulated Delta source with a 10-second trigger.
# NOTE(review): this cell runs before the cell that starts the simulated commits; if the Delta
# path does not exist yet the streaming read may fail — confirm the table is already present.
df, query = attach_read_stream("/tmp/delta/stream_source", "10 seconds")

In [0]:
# Start the background writer: 50 commits of 2000 rows each, pausing 10 seconds after each commit.
# The returned Thread handle is not captured here.
start_background_commits("/tmp/delta/stream_source", 50, 2000, 10)

In [0]:
# Load the same Delta path as a batch DataFrame and compare its partition count with the cluster's default parallelism.
check_partitions_and_cores_static(spark.read.format("delta").load("/tmp/delta/stream_source"))

In [0]:
# Plot trigger-execution time for the query's recent micro-batches.
plot_task_durations(query)

In [0]:
# Compare the latest batch duration against the 10-second trigger the query was started with.
diagnose_performance(query,10)

In [0]:
# Stop the streaming query. This does not stop the background commit thread, which is a separate daemon thread.
query.stop()