# Compare kernel statistics

This notebook compares kernel statistics between two Nsight Systems SQLite report files.

We'll specifically compare the `set_prognostic_edmf_precomputed_quantities_precipitation` kernel between:
- `baseline.sqlite` - baseline run
- `mod.sqlite` - modified run

## Imports and setup

In [1]:
import difflib
import re
import sqlite3
import unicodedata
from pathlib import Path

import pandas as pd
from IPython.display import display

## Helper functions

In [2]:
def get_kernel_stats(db_path, kernel_name_pattern):
    """
    Extract kernel statistics from an nsys SQLite database.

    Args:
        db_path: Path to the SQLite database file
        kernel_name_pattern: SQL LIKE pattern to match kernel names

    Returns:
        Dictionary with aggregated statistics and list of all kernel invocations
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Query to get all kernel invocations with the specified name
    query = """
    SELECT
        s.value as kernelName,
        k.start,
        k.end,
        (k.end - k.start) as duration_ns,
        k.gridX,
        k.gridY,
        k.gridZ,
        k.blockX,
        k.blockY,
        k.blockZ,
        k.registersPerThread,
        k.staticSharedMemory,
        k.dynamicSharedMemory,
        k.localMemoryPerThread,
        k.localMemoryTotal,
        k.deviceId,
        k.streamId,
        k.launchType,
        k.sharedMemoryExecuted,
        k.correlationId,
        k.globalPid
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    JOIN StringIds s ON k.demangledName = s.id
    WHERE s.value LIKE ?
    ORDER BY k.start
    """

    cursor.execute(query, (kernel_name_pattern,))
    rows = cursor.fetchall()

    if not rows:
        conn.close()
        return None, None

    # Calculate statistics
    durations = [row[3] for row in rows]

    stats = {
        "kernel_name": rows[0][0],
        "invocation_count": len(rows),
        "duration_ns": {
            "total": sum(durations),
            "mean": sum(durations) / len(durations),
            "min": min(durations),
            "max": max(durations),
            "median": sorted(durations)[len(durations) // 2],
        },
        "grid_dims": {
            "x": rows[0][4],
            "y": rows[0][5],
            "z": rows[0][6],
        },
        "block_dims": {
            "x": rows[0][7],
            "y": rows[0][8],
            "z": rows[0][9],
        },
        "registers_per_thread": rows[0][10],
        "static_shared_memory": rows[0][11],
        "dynamic_shared_memory": rows[0][12],
        "local_memory_per_thread": rows[0][13],
        "local_memory_total": rows[0][14],
        "device_id": rows[0][15],
        "stream_id": rows[0][16],
        "launch_type": rows[0][17],
        "shared_memory_executed": rows[0][18],
        "correlation_id": rows[0][19],
        "global_pid": rows[0][20],
    }

    # Calculate theoretical occupancy
    # Occupancy = (active warps per SM) / (max warps per SM) * 100%
    block_size = rows[0][7] * rows[0][8] * rows[0][9]
    warps_per_block = (block_size + 31) // 32
    max_warps_per_sm = 48

    # Register file size: ~49KB per SM (determined by occupancy calculator matching nsys)
    max_registers_per_sm = 49152  # registers available per SM
    registers_per_block = rows[0][10] * block_size
    blocks_limited_by_registers = (
        max(1, max_registers_per_sm // registers_per_block)
        if registers_per_block > 0
        else 8
    )

    # Shared memory limit (if applicable, but usually not limiting for this kernel)
    max_shared_mem_per_sm = 96000
    shared_mem_per_block = rows[0][11] + rows[0][12]
    blocks_limited_by_shared_mem = (
        max(1, max_shared_mem_per_sm // shared_mem_per_block)
        if shared_mem_per_block > 0
        else 8
    )

    # Effective blocks per SM is the minimum of register and shared memory constraints
    blocks_per_sm = min(
        blocks_limited_by_registers, blocks_limited_by_shared_mem
    )

    # Calculate occupancy
    active_warps = warps_per_block * blocks_per_sm
    occupancy = (active_warps / max_warps_per_sm) * 100
    stats["theoretical_occupancy"] = occupancy
    conn.close()
    return stats, rows


def get_launch_type_name(launch_type_id):
    """Convert launch type ID to human-readable name."""
    launch_types = {
        0: "Regular",
        1: "Cooperative",
    }
    return launch_types.get(launch_type_id, f"Unknown ({launch_type_id})")


In [3]:
def format_duration(ns):
    """Convert nanoseconds to human-readable format."""
    if ns < 1000:
        return f"{ns:.2f} ns"
    elif ns < 1_000_000:
        return f"{ns / 1000:.2f} µs"
    elif ns < 1_000_000_000:
        return f"{ns / 1_000_000:.2f} ms"
    else:
        return f"{ns / 1_000_000_000:.2f} s"

In [4]:
def compare_stats(baseline_stats, mod_stats):
    """
    Compare two sets of kernel statistics and display results in a single table.

    Args:
        baseline_stats: Statistics from baseline run
        mod_stats: Statistics from modified run
    """
    if baseline_stats is None or mod_stats is None:
        print(
            "Error: Could not retrieve statistics from one or both databases"
        )
        return

    print("\n" + "=" * 140)
    print(f"KERNEL: {baseline_stats['kernel_name']}")
    print("=" * 140 + "\n")

    # Build comprehensive data structure with all metrics
    data = {"Metric": [], "Baseline": [], "Modified": [], "% Change": []}

    # Timing Statistics
    metrics = [
        (
            "Invocations",
            baseline_stats["invocation_count"],
            mod_stats["invocation_count"],
            True,
        ),
        (
            "Total Duration",
            baseline_stats["duration_ns"]["total"],
            mod_stats["duration_ns"]["total"],
            False,
        ),
        (
            "Mean Duration",
            baseline_stats["duration_ns"]["mean"],
            mod_stats["duration_ns"]["mean"],
            False,
        ),
        (
            "Median Duration",
            baseline_stats["duration_ns"]["median"],
            mod_stats["duration_ns"]["median"],
            False,
        ),
        (
            "Min Duration",
            baseline_stats["duration_ns"]["min"],
            mod_stats["duration_ns"]["min"],
            False,
        ),
        (
            "Max Duration",
            baseline_stats["duration_ns"]["max"],
            mod_stats["duration_ns"]["max"],
            False,
        ),
    ]

    for metric_name, baseline_val, mod_val, is_count in metrics:
        data["Metric"].append(metric_name)

        if is_count:
            data["Baseline"].append(f"{baseline_val:,}")
            data["Modified"].append(f"{mod_val:,}")
            pct = (
                (mod_val - baseline_val) / baseline_val * 100
                if baseline_val > 0
                else 0
            )
        else:
            data["Baseline"].append(format_duration(baseline_val))
            data["Modified"].append(format_duration(mod_val))
            pct = (mod_val - baseline_val) / baseline_val * 100

        sign = "+" if pct > 0 else ""
        data["% Change"].append(f"{sign}{pct:.2f}%")

    # Kernel Dimensions & Launch
    data["Metric"].append("Grid Dimensions")
    data["Baseline"].append(
        f"<{baseline_stats['grid_dims']['x']}, {baseline_stats['grid_dims']['y']}, {baseline_stats['grid_dims']['z']}>"
    )
    data["Modified"].append(
        f"<{mod_stats['grid_dims']['x']}, {mod_stats['grid_dims']['y']}, {mod_stats['grid_dims']['z']}>"
    )
    data["% Change"].append("—")

    data["Metric"].append("Block Dimensions")
    data["Baseline"].append(
        f"<{baseline_stats['block_dims']['x']}, {baseline_stats['block_dims']['y']}, {baseline_stats['block_dims']['z']}>"
    )
    data["Modified"].append(
        f"<{mod_stats['block_dims']['x']}, {mod_stats['block_dims']['y']}, {mod_stats['block_dims']['z']}>"
    )
    data["% Change"].append("—")

    block_size_baseline = (
        baseline_stats["block_dims"]["x"]
        * baseline_stats["block_dims"]["y"]
        * baseline_stats["block_dims"]["z"]
    )
    block_size_mod = (
        mod_stats["block_dims"]["x"]
        * mod_stats["block_dims"]["y"]
        * mod_stats["block_dims"]["z"]
    )
    data["Metric"].append("Block Size (threads)")
    data["Baseline"].append(f"{block_size_baseline:,}")
    data["Modified"].append(f"{block_size_mod:,}")
    data["% Change"].append("—")

    grid_size_baseline = (
        baseline_stats["grid_dims"]["x"]
        * baseline_stats["grid_dims"]["y"]
        * baseline_stats["grid_dims"]["z"]
    )
    grid_size_mod = (
        mod_stats["grid_dims"]["x"]
        * mod_stats["grid_dims"]["y"]
        * mod_stats["grid_dims"]["z"]
    )
    data["Metric"].append("Grid Size (blocks)")
    data["Baseline"].append(f"{grid_size_baseline:,}")
    data["Modified"].append(f"{grid_size_mod:,}")
    data["% Change"].append("—")

    total_threads_baseline = grid_size_baseline * block_size_baseline
    total_threads_mod = grid_size_mod * block_size_mod
    data["Metric"].append("Total Threads")
    data["Baseline"].append(f"{total_threads_baseline:,}")
    data["Modified"].append(f"{total_threads_mod:,}")
    data["% Change"].append("—")

    data["Metric"].append("Launch Type")
    data["Baseline"].append(
        get_launch_type_name(baseline_stats["launch_type"])
    )
    data["Modified"].append(get_launch_type_name(mod_stats["launch_type"]))
    data["% Change"].append("—")

    # Register Usage
    data["Metric"].append("Registers Per Thread")
    data["Baseline"].append(f"{baseline_stats['registers_per_thread']}")
    data["Modified"].append(f"{mod_stats['registers_per_thread']}")
    if (
        baseline_stats["registers_per_thread"]
        != mod_stats["registers_per_thread"]
    ):
        pct = (
            (
                mod_stats["registers_per_thread"]
                - baseline_stats["registers_per_thread"]
            )
            / baseline_stats["registers_per_thread"]
            * 100
        )
        sign = "+" if pct > 0 else ""
        data["% Change"].append(f"{sign}{pct:.2f}%")
    else:
        data["% Change"].append("—")

    # Memory Usage
    memory_metrics = [
        (
            "Static Shared Memory (bytes)",
            baseline_stats["static_shared_memory"],
            mod_stats["static_shared_memory"],
        ),
        (
            "Dynamic Shared Memory (bytes)",
            baseline_stats["dynamic_shared_memory"],
            mod_stats["dynamic_shared_memory"],
        ),
        (
            "Shared Memory Executed (bytes)",
            baseline_stats["shared_memory_executed"],
            mod_stats["shared_memory_executed"],
        ),
        (
            "Local Memory Per Thread (bytes)",
            baseline_stats["local_memory_per_thread"],
            mod_stats["local_memory_per_thread"],
        ),
        (
            "Local Memory Total (bytes)",
            baseline_stats["local_memory_total"],
            mod_stats["local_memory_total"],
        ),
    ]

    for metric_name, baseline_val, mod_val in memory_metrics:
        data["Metric"].append(metric_name)
        data["Baseline"].append(f"{baseline_val:,}")
        data["Modified"].append(f"{mod_val:,}")

        if baseline_val == mod_val:
            data["% Change"].append("—")
        elif baseline_val == 0:
            data["% Change"].append("—")
        else:
            pct = (mod_val - baseline_val) / baseline_val * 100
            sign = "+" if pct > 0 else ""
            data["% Change"].append(f"{sign}{pct:.2f}%")

    # Theoretical Occupancy
    data["Metric"].append("Theoretical Occupancy (%)")
    data["Baseline"].append(f"{baseline_stats['theoretical_occupancy']:.1f}%")
    data["Modified"].append(f"{mod_stats['theoretical_occupancy']:.1f}%")
    data["% Change"].append("—")

    # Create and display the combined dataframe
    df = pd.DataFrame(data)
    display(df)

    # Performance summary
    print("\n" + "=" * 140)
    baseline_mean = baseline_stats["duration_ns"]["mean"]
    mod_mean = mod_stats["duration_ns"]["mean"]

    if mod_mean < baseline_mean:
        improvement = ((baseline_mean - mod_mean) / baseline_mean) * 100
        print(
            f"✅ Modified version is {improvement:.2f}% FASTER (mean duration)"
        )
    elif mod_mean > baseline_mean:
        regression = ((mod_mean - baseline_mean) / baseline_mean) * 100
        print(
            f"⚠️  Modified version is {regression:.2f}% SLOWER (mean duration)"
        )
    else:
        print("➡️  No change in mean duration")

    print("=" * 140)


## Configuration

In [5]:
# Database paths
baseline_db = Path("../results/nsys/baseline.sqlite")
mod_db = Path("../results/nsys/mod.sqlite")

# Kernel name pattern to match (avoid fragile line numbers)
kernel_pattern = "%set_prognostic_edmf_precomputed_quantities_precipitation%"

# Check if files exist
if not baseline_db.exists():
    print(f"Error: Baseline database not found: {baseline_db}")
else:
    print(f"✓ Found baseline database: {baseline_db}")

if not mod_db.exists():
    print(f"Error: Modified database not found: {mod_db}")
else:
    print(f"✓ Found modified database: {mod_db}")

✓ Found baseline database: ../results/nsys/baseline.sqlite
✓ Found modified database: ../results/nsys/mod.sqlite


## Compare kernel statistics

In [6]:
def normalize_kernel_name(name):
    """Normalize kernel names for robust matching."""
    normalized = unicodedata.normalize("NFKC", name)
    normalized = "".join(ch for ch in normalized if ch.isprintable())
    normalized = re.sub(r"\s+", " ", normalized).strip()
    return normalized


def strip_mod_suffix(name):
    """Remove a loose mod suffix for matching."""
    return re.sub(r"(?:_|\s)mod\b", "", name).strip()


def list_top_kernels(db_path, pattern="%", limit=15):
    """
    List top kernels by total duration, filtered by a name pattern.

    Args:
        db_path: Path to SQLite database
        pattern: SQL LIKE pattern for kernel names
        limit: Number of kernels to return
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    query = """
    SELECT s.value as kernelName, SUM(k.end - k.start) as total_ns, COUNT(*) as count
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    JOIN StringIds s ON k.demangledName = s.id
    WHERE s.value LIKE ?
    GROUP BY s.value
    ORDER BY total_ns DESC
    LIMIT ?
    """

    cursor.execute(query, (pattern, limit))
    results = cursor.fetchall()
    conn.close()

    return results


def select_kernel_name(baseline_top, mod_top):
    """
    Select the most expensive baseline kernel and the closest name match in modified.
    """
    if not baseline_top:
        raise RuntimeError("No kernels matched in baseline")

    if not mod_top:
        raise RuntimeError("No kernels matched in modified")

    baseline_kernel, baseline_total, _ = baseline_top[0]
    baseline_norm = strip_mod_suffix(normalize_kernel_name(baseline_kernel))

    print(
        f"Selected baseline kernel: {baseline_kernel} (total={format_duration(baseline_total)})"
    )

    best = None
    for mod_kernel, mod_total, _ in mod_top:
        mod_norm = strip_mod_suffix(normalize_kernel_name(mod_kernel))
        score = difflib.SequenceMatcher(None, baseline_norm, mod_norm).ratio()
        candidate = (score, mod_total, mod_kernel)
        if best is None or candidate > best:
            best = candidate

    best_score, best_total, best_kernel = best
    print(
        f"Selected modified kernel: {best_kernel} (total={format_duration(best_total)}, score={best_score:.3f})"
    )

    return baseline_kernel, best_kernel


# Automatically select the hottest kernel from baseline (no hardcoded patterns)
print("Selecting the hottest kernel from baseline...\n")

print("Top 5 kernels in baseline (by total time):")
baseline_top = list_top_kernels(baseline_db, pattern="%", limit=5)
for kernel_name, total_ns, count in baseline_top:
    print(f"  {format_duration(total_ns):>10}  [{count:4d} invocations] {kernel_name}")

print("\nTop 5 kernels in modified (by total time):")
mod_top = list_top_kernels(mod_db, pattern="%", limit=5)
for kernel_name, total_ns, count in mod_top:
    print(f"  {format_duration(total_ns):>10}  [{count:4d} invocations] {kernel_name}")

baseline_kernel_name, mod_kernel_name = select_kernel_name(
    baseline_top, mod_top
)

Selecting the hottest kernel from baseline...

Top 5 kernels in baseline (by total time):
   460.15 ms  [  25 invocations] set_microphysics_tendency_cache__FILE_ClimaAtmos_jl_src_cache_microphysics_cache_jl_L1096
   389.42 ms  [7515 invocations] foreach_FILE_ClimaCore_vVw9t_src_MatrixFields_field_name_set_jl_L81
   283.55 ms  [ 450 invocations] run_field_matrix_solver__FILE_ClimaCore_vVw9t_src_MatrixFields_field_matrix_solver_jl_L352
   249.04 ms  [ 990 invocations] run_field_matrix_solver__FILE_ClimaCore_vVw9t_src_MatrixFields_field_matrix_solver_jl_L354
    87.42 ms  [4185 invocations] regridded_snapshot__FILE_ClimaUtilities_5ClV5_ext_DataHandlingExt_jl_L666

Top 5 kernels in modified (by total time):


   389.90 ms  [7515 invocations] foreach_FILE_ClimaCore_vVw9t_src_MatrixFields_field_name_set_jl_L81
   282.43 ms  [ 450 invocations] run_field_matrix_solver__FILE_ClimaCore_vVw9t_src_MatrixFields_field_matrix_solver_jl_L352
   254.31 ms  [  25 invocations] set_microphysics_tendency_cache__FILE_ClimaAtmos_jl_mod_src_cache_microphysics_cache_jl_L1096
   250.34 ms  [ 990 invocations] run_field_matrix_solver__FILE_ClimaCore_vVw9t_src_MatrixFields_field_matrix_solver_jl_L354
    87.53 ms  [4185 invocations] regridded_snapshot__FILE_ClimaUtilities_5ClV5_ext_DataHandlingExt_jl_L666
Selected baseline kernel: set_microphysics_tendency_cache__FILE_ClimaAtmos_jl_src_cache_microphysics_cache_jl_L1096 (total=460.15 ms)
Selected modified kernel: set_microphysics_tendency_cache__FILE_ClimaAtmos_jl_mod_src_cache_microphysics_cache_jl_L1096 (total=254.31 ms, score=0.978)


In [7]:
# Get statistics from both databases
if baseline_kernel_name is None or mod_kernel_name is None:
    print("Error: Missing kernel name selection; check the listing above.")
else:
    baseline_stats, baseline_rows = get_kernel_stats(
        baseline_db, baseline_kernel_name
    )
    mod_stats, mod_rows = get_kernel_stats(mod_db, mod_kernel_name)

    # Compare and display results
    compare_stats(baseline_stats, mod_stats)


KERNEL: set_microphysics_tendency_cache__FILE_ClimaAtmos_jl_src_cache_microphysics_cache_jl_L1096



Unnamed: 0,Metric,Baseline,Modified,% Change
0,Invocations,25,25,0.00%
1,Total Duration,460.15 ms,254.31 ms,-44.73%
2,Mean Duration,18.41 ms,10.17 ms,-44.73%
3,Median Duration,18.40 ms,10.16 ms,-44.76%
4,Min Duration,18.36 ms,10.14 ms,-44.80%
5,Max Duration,18.49 ms,10.36 ms,-43.98%
6,Grid Dimensions,"<4, 4, 1536>","<4, 4, 1536>",—
7,Block Dimensions,"<64, 1, 1>","<64, 1, 1>",—
8,Block Size (threads),64,64,—
9,Grid Size (blocks),24576,24576,—



✅ Modified version is 44.73% FASTER (mean duration)


In [8]:
# Compare SYPD

log_path_template = "../.calkit/slurm/logs/nsys-{case}.out"


def get_sypd_time_series(case) -> list[float]:
    with open(log_path_template.format(case=case)) as f:
        lines = f.readlines()
    sypd_time_series = []
    for line in lines:
        if "estimated_sypd =" in line and "Inf" not in line:
            sypd_time_series.append(float(line.split()[-1].replace('"', "")))
    return sypd_time_series


df = pd.DataFrame(
    {
        "baseline": get_sypd_time_series("baseline"),
        "mod": get_sypd_time_series("mod"),
    }
)
df["diff_pct"] = (df["mod"] - df["baseline"]) / df["baseline"] * 100
df

Unnamed: 0,baseline,mod,diff_pct
0,0.254,0.271,6.692913
1,0.327,0.35,7.033639
2,0.401,0.43,7.23192


In [9]:
# --- Total profiling duration analysis ---
def get_total_duration(db_path):
    """Get total duration of the profiling session."""
    import sqlite3

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = """
    SELECT MIN(start) as first_event, MAX(end) as last_event
    FROM CUPTI_ACTIVITY_KIND_KERNEL
    """
    cursor.execute(query)
    result = cursor.fetchone()
    conn.close()
    if result[0] is not None and result[1] is not None:
        return result[1] - result[0]  # nanoseconds
    return None


baseline_duration = get_total_duration(baseline_db)
mod_duration = get_total_duration(mod_db)

assert baseline_duration is not None, "Could not determine baseline duration"
assert mod_duration is not None, "Could not determine modified duration"

print("\nTotal Profiling Duration:")
print(f"  Baseline: {format_duration(baseline_duration)}")
print(f"  Modified: {format_duration(mod_duration)}")

if mod_duration > baseline_duration:
    overhead = ((mod_duration - baseline_duration) / baseline_duration) * 100
    print(f"  ⚠️  Modified run is {overhead:.2f}% LONGER overall")
else:
    improvement = (
        (baseline_duration - mod_duration) / baseline_duration
    ) * 100
    print(f"  ✅ Modified run is {improvement:.2f}% SHORTER overall")


# --- Total GPU kernel time analysis ---
def get_total_gpu_time(db_path):
    """Get total GPU time summed across all kernels."""
    import sqlite3

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = """
    SELECT SUM(end - start) as total_gpu_ns
    FROM CUPTI_ACTIVITY_KIND_KERNEL
    """
    cursor.execute(query)
    result = cursor.fetchone()
    conn.close()
    return result[0] if result[0] is not None else 0


baseline_gpu_time = get_total_gpu_time(baseline_db)
mod_gpu_time = get_total_gpu_time(mod_db)

print("\nTotal GPU Kernel Time:")
print(f"  Baseline: {format_duration(baseline_gpu_time)}")
print(f"  Modified: {format_duration(mod_gpu_time)}")

gpu_time_diff = ((mod_gpu_time - baseline_gpu_time) / baseline_gpu_time) * 100
sign = "+" if gpu_time_diff > 0 else ""
print(f"  Change: {sign}{gpu_time_diff:.2f}%")

# Calculate "idle" time (wall time - GPU time)
baseline_idle = baseline_duration - baseline_gpu_time  # type: ignore
mod_idle = mod_duration - mod_gpu_time  # type: ignore

print("\nCPU/Sync Overhead (Total Duration - GPU Time):")
print(f"  Baseline: {format_duration(baseline_idle)}")
print(f"  Modified: {format_duration(mod_idle)}")

if mod_idle > baseline_idle:
    overhead_increase = ((mod_idle - baseline_idle) / baseline_idle) * 100
    print(f"  ⚠️  Modified has {overhead_increase:.2f}% MORE overhead")
else:
    overhead_decrease = ((baseline_idle - mod_idle) / baseline_idle) * 100
    print(f"  ✅ Modified has {overhead_decrease:.2f}% LESS overhead")


# --- Kernel count and new/missing kernels ---
def get_all_kernel_names(db_path):
    """Get all unique kernel names in the database."""
    import sqlite3

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = """
    SELECT DISTINCT s.value
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    JOIN StringIds s ON k.demangledName = s.id
    """
    cursor.execute(query)
    results = {row[0] for row in cursor.fetchall()}
    conn.close()
    return results


baseline_kernel_set = get_all_kernel_names(baseline_db)
mod_kernel_set = get_all_kernel_names(mod_db)

# Remove _mod_ from names for comparison
mod_kernel_set = {name.replace("_mod_", "_") for name in mod_kernel_set}

print("\nKernel Count Comparison:")
print(f"  Baseline: {len(baseline_kernel_set)} unique kernels")
print(f"  Modified: {len(mod_kernel_set)} unique kernels")

new_kernels = mod_kernel_set - baseline_kernel_set
removed_kernels = baseline_kernel_set - mod_kernel_set

if new_kernels:
    print(f"\n⚠️  {len(new_kernels)} NEW kernels in modified:")
    for kernel in sorted(new_kernels)[:10]:
        print(f"    + {kernel}")
    if len(new_kernels) > 10:
        print(f"    ... and {len(new_kernels) - 10} more")

if removed_kernels:
    print(f"\n✅ {len(removed_kernels)} kernels REMOVED in modified:")
    for kernel in sorted(removed_kernels)[:10]:
        print(f"    - {kernel}")
    if len(removed_kernels) > 10:
        print(f"    ... and {len(removed_kernels) - 10} more")


# --- Kernel invocation count changes ---
def get_kernel_invocation_counts(db_path):
    """Get invocation counts for all kernels."""
    import sqlite3

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = """
    SELECT s.value, COUNT(*) as count
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    JOIN StringIds s ON k.demangledName = s.id
    GROUP BY s.value
    ORDER BY count DESC
    """
    cursor.execute(query)
    results = {row[0]: row[1] for row in cursor.fetchall()}
    conn.close()
    return results


baseline_counts = get_kernel_invocation_counts(baseline_db)
mod_counts = get_kernel_invocation_counts(mod_db)

common_kernels = set(baseline_counts.keys()) & set(mod_counts.keys())

print("\nKernel Invocation Changes (for common kernels):")
significant_changes = []
for kernel in common_kernels:
    baseline_count = baseline_counts[kernel]
    mod_count = mod_counts[kernel]
    if baseline_count != mod_count:
        pct_change = ((mod_count - baseline_count) / baseline_count) * 100
        significant_changes.append(
            (abs(pct_change), kernel, baseline_count, mod_count, pct_change)
        )
significant_changes.sort(reverse=True)
if significant_changes:
    print("\nTop 10 kernels with changed invocation counts:")
    for (
        _,
        kernel,
        baseline_count,
        mod_count,
        pct_change,
    ) in significant_changes[:10]:
        sign = "+" if pct_change > 0 else ""
        print(
            f"  {sign}{pct_change:+6.1f}%: {baseline_count:>6} → {mod_count:>6} : {kernel[:80]}"
        )
else:
    print("  ✓ All common kernels have identical invocation counts")


# --- Kernel family analysis ---
def aggregate_by_pattern(counts_dict, duration_dict, pattern):
    """Aggregate kernels matching a pattern."""
    total_count = 0
    total_duration = 0
    for kernel_name in counts_dict:
        if pattern in kernel_name:
            total_count += counts_dict[kernel_name]
            total_duration += duration_dict.get(kernel_name, 0)
    return total_count, total_duration


def get_kernel_durations(db_path):
    """Get total duration for each kernel."""
    import sqlite3

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = """
    SELECT s.value, SUM(k.end - k.start) as total_ns
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    JOIN StringIds s ON k.demangledName = s.id
    GROUP BY s.value
    """
    cursor.execute(query)
    results = {row[0]: row[1] for row in cursor.fetchall()}
    conn.close()
    return results


baseline_durations = get_kernel_durations(baseline_db)
mod_durations = get_kernel_durations(mod_db)

patterns_to_check = [
    "set_prognostic_edmf_precomputed_quantities",
    "microphysics",
    "precipitation",
    "tendency",
    "implicit",
    "ldiv",
    "Wfact",
]

print("\nKernel Family Analysis (by name pattern):")
for pattern in patterns_to_check:
    baseline_count, baseline_duration = aggregate_by_pattern(
        baseline_counts, baseline_durations, pattern
    )
    mod_count, mod_duration = aggregate_by_pattern(
        mod_counts, mod_durations, pattern
    )
    if baseline_count > 0 or mod_count > 0:
        count_change = (
            ((mod_count - baseline_count) / max(baseline_count, 1)) * 100
            if baseline_count > 0
            else float("inf")
        )
        duration_change = (
            ((mod_duration - baseline_duration) / max(baseline_duration, 1))
            * 100
            if baseline_duration > 0
            else float("inf")
        )
        print(f"\n  Pattern: '{pattern}'")
        print(
            f"    Invocations: {baseline_count:>6} → {mod_count:>6} ({count_change:+.1f}%)"
        )
        print(
            f"    Total time:  {format_duration(baseline_duration):>10} → {format_duration(mod_duration):>10} ({duration_change:+.1f}%)"
        )


Total Profiling Duration:
  Baseline: 3.28 s
  Modified: 3.29 s
  ⚠️  Modified run is 0.13% LONGER overall

Total GPU Kernel Time:
  Baseline: 3.18 s
  Modified: 2.97 s
  Change: -6.42%

CPU/Sync Overhead (Total Duration - GPU Time):
  Baseline: 108.08 ms
  Modified: 316.09 ms
  ⚠️  Modified has 192.45% MORE overhead

Kernel Count Comparison:
  Baseline: 265 unique kernels
  Modified: 265 unique kernels

Kernel Invocation Changes (for common kernels):
  ✓ All common kernels have identical invocation counts



Kernel Family Analysis (by name pattern):

  Pattern: 'set_prognostic_edmf_precomputed_quantities'
    Invocations:   1245 →   1245 (+0.0%)
    Total time:    46.11 ms →   46.34 ms (+0.5%)

  Pattern: 'microphysics'
    Invocations:   2480 →   2480 (+0.0%)
    Total time:   545.57 ms →  340.10 ms (-37.7%)

  Pattern: 'precipitation'
    Invocations:   1580 →   1580 (+0.0%)
    Total time:    50.08 ms →   50.37 ms (+0.6%)

  Pattern: 'tendency'
    Invocations:   6695 →   6695 (+0.0%)
    Total time:   911.62 ms →  705.78 ms (-22.6%)

  Pattern: 'implicit'
    Invocations:  11810 →  11810 (+0.0%)
    Total time:   807.20 ms →  807.53 ms (+0.0%)
