# TaskState Branch Comparison

This notebook tests how TaskState.nbytes changes with different numbers of branches read.

## Hypothesis

If we read more branches from ROOT files:
- Coffea's `bytesread` should increase (more data read from disk)
- TaskState's `nbytes` (output size) should also increase (larger results)

## Approach

1. Create processors that read 1, 2, 3, 5, and 10 branches
2. Monitor TaskState during each run
3. Compare task output sizes across runs
4. Compare with Coffea's bytesread

In [1]:
import datetime
import threading
import time

import awkward as ak
import coffea
import matplotlib.pyplot as plt
import pandas as pd
import skhep_testdata
from coffea import processor
from coffea.nanoevents import NanoAODSchema
from dask.distributed import Client, LocalCluster

coffea.__version__

'2025.11.0'

## Task Monitor (from previous notebook)

In [2]:
class TaskMonitor:
    """Poll a Dask scheduler from a background thread and record task snapshots.

    Each snapshot is a dict with the capture ``timestamp``, the number of
    tasks and workers known to the scheduler, and one record per task
    (``key``, ``state``, ``worker``, ``nbytes``, ``type``).

    Args:
        scheduler: Object exposing ``tasks`` (mapping of key -> task state)
            and ``workers`` (sized container), e.g. ``client.cluster.scheduler``.
        interval: Seconds to sleep between two snapshots.
    """

    def __init__(self, scheduler, interval=0.1):
        self.scheduler = scheduler
        self.interval = interval
        self.snapshots = []
        self.running = False
        self.thread = None

    def start(self):
        """Start monitoring; a second call while already running is a no-op."""
        if self.thread is not None and self.thread.is_alive():
            if self.running:
                # Already polling: a second thread would only duplicate snapshots.
                return
            # A previous stop() timed out; let that loop finish before restarting.
            self.thread.join()

        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop monitoring and wait (up to 5 s) for the polling thread to exit."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=5)
            if not self.thread.is_alive():
                self.thread = None

    def _monitor_loop(self):
        """Background monitoring loop; keeps only snapshots that contain tasks."""
        while self.running:
            try:
                snapshot = self._capture_snapshot()
                if snapshot["tasks"]:
                    self.snapshots.append(snapshot)
            except Exception as e:
                # Best effort: a failed poll must not kill the monitor thread.
                print(f"Error in monitoring: {e}")

            time.sleep(self.interval)

    def _capture_snapshot(self):
        """Capture current scheduler state.

        Returns:
            dict with keys ``timestamp``, ``num_tasks``, ``num_workers`` and
            ``tasks`` (list of per-task dicts).
        """
        timestamp = datetime.datetime.now()

        # Copy the items up front: this runs in a separate thread, and the
        # scheduler may add or remove tasks while we walk the mapping, which
        # would otherwise raise "dictionary changed size during iteration".
        task_items = list(self.scheduler.tasks.items())

        tasks = []
        for task_key, task_state in task_items:
            holders = getattr(task_state, "who_has", None)
            holder = next(iter(holders), None) if holders else None
            worker = holder.address if holder is not None else None

            # Unknown sizes (distributed uses -1 for tasks without a result)
            # are recorded as 0 so they never contribute to size statistics.
            nbytes = max(getattr(task_state, "nbytes", 0) or 0, 0)

            tasks.append(
                {
                    "key": task_key,
                    "state": task_state.state,
                    "worker": worker,
                    "nbytes": nbytes,
                    "type": getattr(task_state, "type", None),
                }
            )

        return {
            "timestamp": timestamp,
            "num_tasks": len(tasks),
            "num_workers": len(self.scheduler.workers),
            "tasks": tasks,
        }

    def get_all_tasks(self):
        """Get all unique tasks seen across all snapshots.

        When a key appears in several snapshots, the record with the largest
        ``nbytes`` is kept.
        """
        all_tasks = {}

        for snapshot in self.snapshots:
            for task in snapshot["tasks"]:
                key = task["key"]
                if key not in all_tasks or task["nbytes"] > all_tasks[key]["nbytes"]:
                    all_tasks[key] = task

        return list(all_tasks.values())

    def reset(self):
        """Clear snapshots for next run."""
        self.snapshots = []

## Create Processors with Different Branch Counts

In [3]:
from memory_profiler import profile


class BranchTestProcessor(processor.ProcessorABC):
    """Processor whose ``process`` step touches a configurable number of branches.

    ``num_branches`` is a threshold: every step whose number is less than or
    equal to it runs, so larger values access progressively more columns
    (Jet pt/eta/phi/mass, Electron pt/eta, Muon pt/eta, MET pt/phi).

    NOTE(review): the per-chunk result holds means; if the framework merges
    chunk results by addition, the ``mean_*`` entries are only meaningful
    for single-chunk runs — confirm before using them as physics numbers.
    """

    def __init__(self, num_branches: int):
        self.num_branches = num_branches

    @staticmethod
    def _mean_or_zero(values):
        """Return the mean over all entries of ``values``, or 0 if it has none.

        ``len()`` of a per-event collection is the number of events, not the
        number of selected objects, so it cannot detect "nothing passed the
        cut". Counting the actual entries does.
        """
        if ak.count(values, axis=None) == 0:
            return 0
        return ak.mean(values)

    # NOTE(review): the recorded run printed "ERROR: Could not find file ...",
    # which looks like the profiler failing to locate source for code defined
    # in a notebook cell — confirm whether this decorator is still wanted.
    @profile
    def process(self, events):
        """Compute summary quantities for one chunk of events.

        Args:
            events: Chunk of events exposing ``Jet``, ``Electron``, ``Muon``
                and ``MET`` collections.

        Returns:
            dict of scalar results; always contains ``nevents``, plus the
            entries of every enabled step.
        """
        n_steps = self.num_branches

        # Always count events
        results = {"nevents": len(events)}

        # Read increasing numbers of branches
        if n_steps >= 1:
            # Branch 1: Jet pt
            jets = events.Jet[events.Jet.pt > 30]
            results["njets"] = ak.sum(ak.num(jets))
            results["mean_jet_pt"] = self._mean_or_zero(jets.pt)

        if n_steps >= 2:
            # Branch 2: Jet eta
            results["mean_jet_eta"] = self._mean_or_zero(jets.eta)

        if n_steps >= 3:
            # Branch 3: Jet phi
            results["mean_jet_phi"] = self._mean_or_zero(jets.phi)

        if n_steps >= 4:
            # Branch 4: Jet mass
            results["mean_jet_mass"] = self._mean_or_zero(jets.mass)

        if n_steps >= 5:
            # Branch 5: Electron pt
            electrons = events.Electron[events.Electron.pt > 20]
            results["nelectrons"] = ak.sum(ak.num(electrons))
            results["mean_electron_pt"] = self._mean_or_zero(electrons.pt)

        if n_steps >= 6:
            # Branch 6: Electron eta
            results["mean_electron_eta"] = self._mean_or_zero(electrons.eta)

        if n_steps >= 7:
            # Branch 7: Muon pt
            muons = events.Muon[events.Muon.pt > 20]
            results["nmuons"] = ak.sum(ak.num(muons))
            results["mean_muon_pt"] = self._mean_or_zero(muons.pt)

        if n_steps >= 8:
            # Branch 8: Muon eta
            results["mean_muon_eta"] = self._mean_or_zero(muons.eta)

        if n_steps >= 9:
            # Branch 9: MET pt
            results["mean_met_pt"] = self._mean_or_zero(events.MET.pt)

        if n_steps >= 10:
            # Branch 10: MET phi
            results["mean_met_phi"] = self._mean_or_zero(events.MET.phi)

        return results

    def postprocess(self, accumulator):
        """Return the accumulated results unchanged."""
        return accumulator

## Start Cluster

In [4]:
# Two single-threaded worker processes on the local machine, plus a client.
cluster = LocalCluster(processes=True, n_workers=2, threads_per_worker=1)
client = Client(cluster)

worker_info = client.scheduler_info()["workers"]
print(f"Dashboard: {client.dashboard_link}")
print(f"Workers: {len(worker_info)}")

# Path to the NanoAOD test file
test_file = skhep_testdata.data_path("nanoAOD_2015_CMS_Open_Data_ttbar.root")

# One dataset ("ttbar") whose single file maps to its tree name
fileset = {"ttbar": {"files": {test_file: "Events"}}}

Perhaps you already have a cluster running?
Hosting the HTTP server on port 64127 instead


Dashboard: http://127.0.0.1:64127/status
Workers: 2


## Run Experiments with Different Branch Counts

In [None]:
from pprint import pprint

# Test with different numbers of branches
branch_counts = [1, 2, 3, 5, 10]


def summarize_run(num_branches, report, tasks):
    """Build the per-run summary row from the metrics report and monitored tasks.

    Args:
        num_branches: Branch-count setting used for this run.
        report: Metrics dict returned by the runner; ``bytesread``, ``chunks``
            and ``entries`` are read from it.
        tasks: Task records as returned by ``TaskMonitor.get_all_tasks()``.

    Returns:
        dict with one entry per summary column.
    """
    # "BranchTestProcessor" itself contains "process", so one case-insensitive
    # substring test covers both of the key patterns we are interested in.
    processor_tasks = [task for task in tasks if "process" in str(task["key"]).lower()]

    # Keep only real sizes: unfinished tasks report a non-positive nbytes
    # (distributed uses -1), which would otherwise skew the sum and the mean.
    task_sizes = [
        task["nbytes"]
        for task in processor_tasks
        if task["nbytes"] and task["nbytes"] > 0
    ]
    total_bytes = sum(task_sizes)
    mean_bytes = total_bytes / len(task_sizes) if task_sizes else 0

    return {
        "num_branches": num_branches,
        "coffea_bytesread": report["bytesread"],
        "coffea_bytesread_mb": report["bytesread"] / 1e6,
        "num_processor_tasks": len(processor_tasks),
        "total_task_bytes": total_bytes,
        "total_task_bytes_kb": total_bytes / 1e3,
        "mean_task_bytes": mean_bytes,
        "mean_task_bytes_kb": mean_bytes / 1e3,
        "chunks": report["chunks"],
        "events": report["entries"],
    }


results = []

# Create monitor once
monitor = TaskMonitor(client.cluster.scheduler, interval=0.1)

for num_branches in branch_counts:
    print(f"\n{'=' * 60}")
    print(f"Testing with {num_branches} branches")
    print(f"{'=' * 60}")

    proc = BranchTestProcessor(num_branches=num_branches)
    executor = processor.DaskExecutor(client=client)
    runner = processor.Runner(
        executor=executor,
        savemetrics=True,
        schema=NanoAODSchema,
    )

    # Reset monitor for this run
    monitor.reset()
    monitor.start()
    try:
        output, report = runner(
            fileset,
            treename="Events",
            processor_instance=proc,
        )
        # Wait for final tasks
        time.sleep(1)
    finally:
        # Always stop polling, even if the run failed, so no monitor thread
        # is left running against the scheduler.
        monitor.stop()

    pprint(len(report["columns"]))
    pprint(report["columns"])

    result = summarize_run(num_branches, report, monitor.get_all_tasks())
    results.append(result)

    print("\nResults:")
    print(f"  Coffea bytesread: {result['coffea_bytesread_mb']:.2f} MB")
    print(f"  Processor tasks: {result['num_processor_tasks']}")
    print(f"  Total task output: {result['total_task_bytes_kb']:.2f} KB")
    print(f"  Mean task output: {result['mean_task_bytes_kb']:.2f} KB")
    print(f"  Chunks: {result['chunks']}")
    print(f"  Events: {result['events']}")


Testing with 1 branches


Output()

Output()

ERROR: Could not find file /var/folders/0v/cvfdg7wn2k59f7d0pp1wdsy80000gn/T/ipykernel_8716/1640942170.py




2
[Accessed(branch='nJet', buffer_key='d48060b2-6a57-11ed-8e14-0600a8c0beef/%2FEvents%3B1/0-200/offsets/nJet%2C%21load%2C%21counts2offsets%2C%21skip%2C%21offsets'),
 Accessed(branch='Jet_pt', buffer_key='d48060b2-6a57-11ed-8e14-0600a8c0beef/%2FEvents%3B1/0-200/data/Jet_pt%2C%21load%2C%21content')]


KeyError: 'mem_usage_bytes'

## Summary Table

In [None]:
# One row per run, in the order the runs were executed.
df = pd.DataFrame(results)

summary_columns = [
    "num_branches",
    "coffea_bytesread_mb",
    "total_task_bytes_kb",
    "mean_task_bytes_kb",
    "num_processor_tasks",
    "chunks",
]

print("\n=== Summary Table ===")
print(df[summary_columns].to_string(index=False))

# Ratios are expressed relative to the first run
print("\n=== Scaling Analysis ===")
if len(df) > 1:
    first_run = df.iloc[0]
    baseline_bytesread = first_run["coffea_bytesread_mb"]
    baseline_task_bytes = first_run["total_task_bytes_kb"]

    df["bytesread_ratio"] = df["coffea_bytesread_mb"].div(baseline_bytesread)
    df["task_bytes_ratio"] = df["total_task_bytes_kb"].div(baseline_task_bytes)

    ratio_columns = ["num_branches", "bytesread_ratio", "task_bytes_ratio"]
    print(df[ratio_columns].to_string(index=False))

## Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plots 1-3 share one layout: a single series against the branch count.
# Each entry: (grid position, column, y label, title, extra line style)
single_series_panels = [
    (
        (0, 0),
        "coffea_bytesread_mb",
        "Coffea Bytesread (MB)",
        "Coffea Bytesread vs Number of Branches",
        {},
    ),
    (
        (0, 1),
        "total_task_bytes_kb",
        "Total Task Output (KB)",
        "TaskState Total Output vs Number of Branches",
        {"color": "orange"},
    ),
    (
        (1, 0),
        "mean_task_bytes_kb",
        "Mean Task Output (KB)",
        "Mean Task Output vs Number of Branches",
        {"color": "green"},
    ),
]

for position, column, ylabel, title, line_style in single_series_panels:
    ax = axes[position]
    ax.plot(
        df["num_branches"], df[column], "o-", linewidth=2, markersize=8, **line_style
    )
    ax.set_xlabel("Number of Branches")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

# Plot 4: both scaling ratios on one panel (ratio columns need more than one run)
if len(df) > 1:
    ratio_ax = axes[1, 1]
    ratio_series = [
        ("bytesread_ratio", "o-", "Coffea Bytesread"),
        ("task_bytes_ratio", "s-", "TaskState Output"),
    ]
    for column, fmt, label in ratio_series:
        ratio_ax.plot(
            df["num_branches"], df[column], fmt, linewidth=2, markersize=8, label=label
        )
    ratio_ax.set_xlabel("Number of Branches")
    ratio_ax.set_ylabel("Ratio (relative to 1 branch)")
    ratio_ax.set_title("Scaling Ratios")
    ratio_ax.legend()
    ratio_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Analysis and Conclusions

In [None]:
print("=== Key Findings ===")
print()

# Everything below relies on the ratio columns, which need more than one run.
if len(df) > 1:
    # Off-diagonal entry of the 2x2 correlation matrix
    corr_matrix = df[["coffea_bytesread_mb", "total_task_bytes_kb"]].corr()
    correlation = corr_matrix.loc["coffea_bytesread_mb", "total_task_bytes_kb"]

    print(
        f"1. Correlation between Coffea bytesread and TaskState output: {correlation:.3f}"
    )
    print()

    # Mean of each ratio column across all runs
    avg_bytesread_increase = df["bytesread_ratio"].mean()
    avg_task_bytes_increase = df["task_bytes_ratio"].mean()

    print(f"2. Average bytesread scaling factor: {avg_bytesread_increase:.2f}x")
    print(f"   Average task output scaling factor: {avg_task_bytes_increase:.2f}x")
    print()

    # Mean absolute gap between the two ratio curves
    ratio_difference = (df["bytesread_ratio"] - df["task_bytes_ratio"]).abs().mean()

    print(f"3. Mean difference in scaling ratios: {ratio_difference:.3f}")
    scaling_verdict = (
        "   → TaskState output scales similarly to Coffea bytesread ✅"
        if ratio_difference < 0.2
        else "   → TaskState output scales differently from Coffea bytesread ⚠️"
    )
    print(scaling_verdict)
    print()

    interpretation = [
        "4. Interpretation:",
        "   - TaskState.nbytes measures OUTPUT size (pickled result)",
        "   - Coffea bytesread measures INPUT size (compressed ROOT data)",
    ]
    if correlation > 0.8:
        interpretation += [
            "   - Strong correlation suggests task output scales with input data",
            "   - TaskState could be useful as a PROXY for relative chunk sizes",
        ]
    else:
        interpretation += [
            "   - Weak correlation suggests task output doesn't directly track input",
            "   - TaskState may not be reliable for chunk size estimation",
        ]
    print("\n".join(interpretation))
    print()

    limitations = [
        "5. Limitations:",
        "   - Absolute values differ greatly (KB vs MB)",
        "   - Cannot calculate actual bytes per chunk from TaskState alone",
        "   - Still need decorator approach for true per-chunk input tracking",
    ]
    print("\n".join(limitations))

## Cleanup

In [None]:
# Close the client first, then the cluster it was attached to.
for resource in (client, cluster):
    resource.close()
print("Cluster closed")