# Real-Time TaskState Exploration

This notebook captures Dask TaskState data **during execution** before tasks are cleaned up.

## Approach

1. Create a background thread that periodically snapshots scheduler state
2. Run Coffea processing
3. Analyze captured task data
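
The three steps above can be sketched without Dask. This is a minimal illustration, not the monitor used in this notebook: `sample_periodically` and the `state` counter are hypothetical stand-ins, and a `threading.Event` is assumed as the stop signal so that the sleep can be interrupted.

```python
import threading
import time


def sample_periodically(read_state, interval, stop_event, out):
    """Append (time, read_state()) to out every `interval` seconds until stopped."""
    while not stop_event.is_set():
        out.append((time.monotonic(), read_state()))
        # wait() acts as an interruptible sleep: it returns early once stop is set
        stop_event.wait(interval)


# Stand-in for scheduler state: a plain counter updated by the "workload"
state = {"done": 0}
snapshots = []
stop = threading.Event()
sampler = threading.Thread(
    target=sample_periodically,
    args=(lambda: dict(state), 0.01, stop, snapshots),
    daemon=True,
)

sampler.start()  # step 1: start background snapshots
for _ in range(5):  # step 2: run the workload
    state["done"] += 1
    time.sleep(0.02)
stop.set()  # step 3: stop sampling, then analyze
sampler.join(timeout=5)

print(f"Captured {len(snapshots)} snapshots")
```

Each snapshot stores a copy of the state (`dict(state)`), so later updates do not rewrite earlier observations.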

In [1]:
import datetime
import threading
import time
from collections import defaultdict

import awkward as ak
import pandas as pd
import skhep_testdata
from coffea import processor
from coffea.nanoevents import NanoAODSchema
from dask.distributed import Client, LocalCluster

## Create Task Monitoring Class

In [None]:
class TaskMonitor:
    """Monitor Dask scheduler tasks in real-time."""

    def __init__(self, scheduler, interval=0.1):
        self.scheduler = scheduler
        self.interval = interval
        self.snapshots = []
        self.running = False
        self.thread = None

    def start(self):
        """Start monitoring."""
        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.thread.start()
        print(f"Started task monitoring (interval={self.interval}s)")

    def stop(self):
        """Stop monitoring."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=5)
        print(f"Stopped monitoring. Captured {len(self.snapshots)} snapshots")

    def _monitor_loop(self):
        """Background monitoring loop."""
        while self.running:
            try:
                snapshot = self._capture_snapshot()
                if snapshot["tasks"]:
                    self.snapshots.append(snapshot)
            except Exception as e:
                print(f"Error in monitoring: {e}")

            time.sleep(self.interval)

    def _capture_snapshot(self):
        """Capture current scheduler state."""
        timestamp = datetime.datetime.now()

        tasks = []
        # Iterate over a copy: the scheduler can add or drop tasks while this thread reads
        for task_key, task_state in list(self.scheduler.tasks.items()):
            # Address of one worker holding the result, if any
            who_has = getattr(task_state, "who_has", None)
            worker = next(iter(who_has)).address if who_has else None

            tasks.append(
                {
                    "key": task_key,
                    "state": task_state.state,
                    "worker": worker,
                    "nbytes": getattr(task_state, "nbytes", 0),
                    "type": getattr(task_state, "type", None),
                }
            )

        return {
            "timestamp": timestamp,
            "num_tasks": len(tasks),
            "num_workers": len(self.scheduler.workers),
            "tasks": tasks,
        }

    def get_all_tasks(self):
        """Get all unique tasks seen across all snapshots."""
        all_tasks = {}

        for snapshot in self.snapshots:
            for task in snapshot["tasks"]:
                key = task["key"]
                # Keep the observation with the largest nbytes for each key
                if key not in all_tasks or task["nbytes"] > all_tasks[key]["nbytes"]:
                    all_tasks[key] = task

        return list(all_tasks.values())

## Create Simple Processor

In [None]:
class SimpleProcessor(processor.ProcessorABC):
    """Simple processor for testing."""

    def process(self, events):
        # Select jets with pt > 30
        jets = events.Jet[events.Jet.pt > 30]
        # len(jets) counts events, not jets, so test the number of selected jets
        njets = ak.sum(ak.num(jets))

        return {
            "nevents": len(events),
            "njets": njets,
            "mean_pt": ak.mean(jets.pt) if njets > 0 else 0,
            "mean_eta": ak.mean(jets.eta) if njets > 0 else 0,
        }

    def postprocess(self, accumulator):
        return accumulator

## Start Cluster and Monitor

In [18]:
# Start cluster
cluster = LocalCluster(n_workers=4, threads_per_worker=1, processes=True)
client = Client(cluster)

print(f"Dashboard: {client.dashboard_link}")
print(f"Workers: {len(client.scheduler_info()['workers'])}")

# Start monitoring
monitor = TaskMonitor(client.cluster.scheduler, interval=0.01)
monitor.start()

# Give it a moment to start
time.sleep(0.5)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 61828 instead


Dashboard: http://127.0.0.1:61828/status
Workers: 4
Started task monitoring (interval=0.01s)


## Run Coffea Processing (While Monitoring)

In [19]:
# Get test file
test_file = skhep_testdata.data_path("nanoAOD_2015_CMS_Open_Data_ttbar.root")

# Create fileset
fileset = {
    "ttbar": {
        "files": {test_file: "Events"},
    },
}

# Run processor
proc = SimpleProcessor()
executor = processor.DaskExecutor(client=client)
runner = processor.Runner(
    executor=executor,
    savemetrics=True,
    schema=NanoAODSchema,
)

print("Starting Coffea processing...")
output, report = runner(
    fileset,
    treename="Events",
    processor_instance=proc,
)

print(f"\nProcessed {report['entries']} events in {report['chunks']} chunks")
print(f"Total bytes read: {report['bytesread'] / 1e6:.2f} MB")

# Wait a bit for final tasks to be captured
time.sleep(1)

# Stop monitoring
monitor.stop()

Starting Coffea processing...

Processed 200 events in 1 chunks
Total bytes read: 0.34 MB
Stopped monitoring. Captured 67 snapshots


## Analyze Captured Data

In [20]:
print("=== Monitoring Summary ===")
print(f"Total snapshots: {len(monitor.snapshots)}")
print(
    f"Duration: {(monitor.snapshots[-1]['timestamp'] - monitor.snapshots[0]['timestamp']).total_seconds():.2f}s"
)

# Show task counts over time
print("\nTask count over time:")
for i, snapshot in enumerate(
    monitor.snapshots[:: max(1, len(monitor.snapshots) // 10)]
):
    print(
        f"  {i}: {snapshot['num_tasks']} tasks at {snapshot['timestamp'].strftime('%H:%M:%S.%f')[:-3]}"
    )

=== Monitoring Summary ===
Total snapshots: 67
Duration: 0.76s

Task count over time:
  0: 2 tasks at 14:11:00.506
  1: 2 tasks at 14:11:00.578
  2: 2 tasks at 14:11:00.646
  3: 2 tasks at 14:11:00.714
  4: 2 tasks at 14:11:00.783
  5: 2 tasks at 14:11:00.850
  6: 2 tasks at 14:11:00.917
  7: 2 tasks at 14:11:00.984
  8: 2 tasks at 14:11:01.054
  9: 2 tasks at 14:11:01.126
  10: 2 tasks at 14:11:01.195
  11: 2 tasks at 14:11:01.265


## Extract All Unique Tasks

In [24]:
all_tasks = monitor.get_all_tasks()
print(f"=== Total unique tasks captured: {len(all_tasks)} ===")

# Filter for processor-related tasks.
# The key filter below is commented out, so every captured task is kept.
processor_tasks = [
    task
    for task in all_tasks
    # if 'SimpleProcessor' in str(task['key']) or 'process' in str(task['key']).lower()
]

print(f"Processor-related tasks: {len(processor_tasks)}")

# Show sample tasks
print("\n=== Sample Task Details ===")
for i, task in enumerate(processor_tasks):
    print(f"\nTask {i + 1}:")
    print(f"  Key: {task['key']}")
    print(f"  State: {task['state']}")
    print(f"  Worker: {task['worker']}")
    print(f"  Size: {task['nbytes'] / 1e3:.2f} KB" if task["nbytes"] else "  Size: N/A")
    print(f"  Type: {task['type']}")

=== Total unique tasks captured: 2 ===
Processor-related tasks: 2

=== Sample Task Details ===

Task 1:
  Key: lambda-479af8013eef8fd3a0f3694766170521
  State: memory
  Worker: tcp://127.0.0.1:61841
  Size: 1.20 KB
  Type: bytes

Task 2:
  Key: SimpleProcessor-6cc9379618d35fa267d3ddfd2223c89b
  State: memory
  Worker: tcp://127.0.0.1:61841
  Size: 0.71 KB
  Type: bytes


## Task Key Structure Analysis

In [None]:
print("=== Task Key Structure ===")

for i, task in enumerate(processor_tasks[:10]):
    key = task["key"]
    print(f"\n{i + 1}. {key}")
    print(f"   Type: {type(key)}")

    if isinstance(key, tuple):
        print(f"   Length: {len(key)}")
        for j, elem in enumerate(key):
            print(f"   [{j}]: {type(elem).__name__} = {elem}")

            # Check for file/dataset info
            if isinstance(elem, str) and any(
                x in elem for x in ["ttbar", "root", "nanoAOD", "file"]
            ):
                print("       ^ Contains file/dataset info!")

=== Task Key Structure ===

1. lambda-479af8013eef8fd3a0f3694766170521
   Type: <class 'str'>

2. SimpleProcessor-6cc9379618d35fa267d3ddfd2223c89b
   Type: <class 'str'>


## Per-Worker Task Distribution

In [None]:
print("=== Per-Worker Distribution ===")

worker_stats = defaultdict(lambda: {"count": 0, "total_bytes": 0, "tasks": []})

for task in processor_tasks:
    worker = task["worker"] or "unknown"
    worker_stats[worker]["count"] += 1
    worker_stats[worker]["total_bytes"] += task["nbytes"] if task["nbytes"] else 0
    worker_stats[worker]["tasks"].append(task)

for worker, stats in worker_stats.items():
    print(f"\nWorker: {worker}")
    print(f"  Tasks: {stats['count']}")
    print(f"  Total bytes: {stats['total_bytes'] / 1e6:.2f} MB")
    print(
        f"  Avg bytes/task: {stats['total_bytes'] / stats['count'] / 1e3:.2f} KB"
        if stats["count"] > 0
        else "  Avg: N/A"
    )

    # Show sample task keys
    print("  Sample tasks:")
    for task in stats["tasks"][:3]:
        size = f"{task['nbytes'] / 1e3:.1f} KB" if task["nbytes"] else "N/A"
        print(f"    {str(task['key'])[:60]}... : {size}")

=== Per-Worker Distribution ===

Worker: tcp://127.0.0.1:61841
  Tasks: 2
  Total bytes: 0.00 MB
  Avg bytes/task: 0.95 KB
  Sample tasks:
    lambda-479af8013eef8fd3a0f3694766170521... : 1.2 KB
    SimpleProcessor-6cc9379618d35fa267d3ddfd2223c89b... : 0.7 KB


## Create DataFrame for Analysis

In [16]:
df = pd.DataFrame(
    [
        {
            "key": str(task["key"])[:60] + "..."
            if len(str(task["key"])) > 60
            else str(task["key"]),
            "worker": task["worker"],
            "state": task["state"],
            "nbytes_kb": task["nbytes"] / 1e3 if task["nbytes"] else 0,
            "type": task["type"],
        }
        for task in processor_tasks
    ]
)

print("=== Task DataFrame ===")
print(df.head(10))

print("\n=== Size Statistics ===")
print(df[["nbytes_kb"]].describe())

print("\n=== Per-Worker Summary ===")
if len(df) > 0:
    print(df.groupby("worker")["nbytes_kb"].agg(["count", "sum", "mean", "std"]))

=== Task DataFrame ===
                                                key                 worker  \
0  SimpleProcessor-6cc9379618d35fa267d3ddfd2223c89b  tcp://127.0.0.1:61763   

    state  nbytes_kb   type  
0  memory      0.707  bytes  

=== Size Statistics ===
       nbytes_kb
count      1.000
mean       0.707
std          NaN
min        0.707
25%        0.707
50%        0.707
75%        0.707
max        0.707

=== Per-Worker Summary ===
                       count    sum   mean  std
worker                                         
tcp://127.0.0.1:61763      1  0.707  0.707  NaN


## Compare with Coffea Report

In [17]:
print("=== Comparison with Coffea Report ===")
print("\nCoffea report:")
print(f"  Chunks: {report['chunks']}")
print(f"  Bytes read: {report['bytesread'] / 1e6:.2f} MB")
print(f"  Events: {report['entries']}")

print("\nTaskState captured:")
print(f"  Processor tasks: {len(processor_tasks)}")
total_task_bytes = sum(task["nbytes"] for task in processor_tasks if task["nbytes"])
print(f"  Total task result sizes: {total_task_bytes / 1e6:.2f} MB")

print("\nNote: TaskState.nbytes is OUTPUT size (result), not INPUT bytes read!")
print(
    f"That's why task sizes ({total_task_bytes / 1e6:.2f} MB) differ from bytesread ({report['bytesread'] / 1e6:.2f} MB)"
)

=== Comparison with Coffea Report ===

Coffea report:
  Chunks: 1
  Bytes read: 0.34 MB
  Events: 200

TaskState captured:
  Processor tasks: 1
  Total task result sizes: 0.00 MB

Note: TaskState.nbytes is OUTPUT size (result), not INPUT bytes read!
That's why task sizes (0.00 MB) differ from bytesread (0.34 MB)


## Key Findings

### What We Successfully Captured:

1. ✅ **Task result sizes** - `TaskState.nbytes` for each task
2. ✅ **Worker attribution** - Which worker holds each task's result (`TaskState.who_has`)
3. ✅ **Task counts per worker** - Load distribution
4. ✅ **Task states** - Scheduler state of each task at snapshot time (e.g. `memory`)
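
`get_all_tasks` keeps only one observation per key, so state transitions are lost. A sketch of how the raw snapshots could be turned into a per-task state history follows. `state_timeline` is a hypothetical helper, and the example snapshots are made up to match the shape returned by `_capture_snapshot`.

```python
from collections import defaultdict


def state_timeline(snapshots):
    """Map each task key to the ordered list of distinct states it was seen in."""
    timeline = defaultdict(list)
    for snapshot in snapshots:
        for task in snapshot["tasks"]:
            states = timeline[task["key"]]
            # Record a state only when it differs from the previous observation
            if not states or states[-1] != task["state"]:
                states.append(task["state"])
    return dict(timeline)


# Hypothetical snapshots, not data from the run above
example_snapshots = [
    {"tasks": [{"key": "SimpleProcessor-abc", "state": "processing"}]},
    {"tasks": [{"key": "SimpleProcessor-abc", "state": "processing"}]},
    {"tasks": [{"key": "SimpleProcessor-abc", "state": "memory"}]},
]
print(state_timeline(example_snapshots))
# {'SimpleProcessor-abc': ['processing', 'memory']}
```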

### Critical Limitations:

1. ❌ **Input bytes** - Only output/result size, not bytes read from file
2. ❌ **Chunk identification** - Task keys are opaque, no file/dataset info
3. ❌ **Event counts** - Not available in TaskState
4. ⚠️ **Overhead** - Polling the scheduler's task table every 0.01 s (the interval used in this run; the class default is 0.1 s) adds overhead
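
To illustrate the second limitation: the two keys captured in this run are plain strings that look like `<name>-<hash>`. The sketch below splits them on the last hyphen. `split_task_key` is a hypothetical helper, and the `<name>-<hash>` layout is an assumption based only on the keys printed above.

```python
def split_task_key(key):
    """Split a '<name>-<hash>' string key into (name, token)."""
    name, sep, token = str(key).rpartition("-")
    if not sep:
        # No hyphen: treat the whole key as the name
        return str(key), None
    return name, token


captured_keys = [
    "lambda-479af8013eef8fd3a0f3694766170521",
    "SimpleProcessor-6cc9379618d35fa267d3ddfd2223c89b",
]
for key in captured_keys:
    print(split_task_key(key))
# ('lambda', '479af8013eef8fd3a0f3694766170521')
# ('SimpleProcessor', '6cc9379618d35fa267d3ddfd2223c89b')
```

The name part tells us which callable ran, but the token is a hash, so neither part identifies the file, dataset, or entry range of the chunk.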

### Can We Use This for Bytes Per Chunk?

**Partially, but with major caveats:**

- We get **output bytes per task**, not input bytes read
- We can see **per-worker distribution** of task sizes
- We **cannot** map tasks to files/datasets without additional tracking
- We **cannot** calculate bytes per event (no event counts)
- We **cannot** calculate throughput per chunk (no input bytes)
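
For contrast, the Coffea report does carry run-level totals (`bytesread`, `entries`, `chunks`), which are enough for an aggregate bytes-per-event figure but not for a per-chunk breakdown. A small sketch, where `bytes_per_event` is a hypothetical helper and the numbers are rounded from the output shown earlier (0.34 MB, 200 events):

```python
def bytes_per_event(report):
    """Aggregate input bytes per event from a Coffea-style report dict."""
    entries = report["entries"]
    return report["bytesread"] / entries if entries else float("nan")


# Illustrative values rounded from this run's printed report
example_report = {"bytesread": 340_000, "entries": 200, "chunks": 1}
print(f"{bytes_per_event(example_report):.0f} bytes/event")
# 1700 bytes/event
```

With more than one chunk, this number averages over all chunks and says nothing about how input bytes were spread across them.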

### Recommendation:

**TaskState tracking provides partial data but is insufficient for your requirements.**

For complete per-chunk tracking with:
- Bytes per chunk (input)
- Worker attribution
- Bytes per event
- Throughput per chunk

**You need the `@track_metrics` decorator approach** that captures:
- Input metadata (file, dataset, events)
- Processing time
- Worker ID
- Memory/size estimates

TaskState could supplement this as an automatic fallback, but cannot replace it.
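
The `@track_metrics` decorator itself is not defined in this notebook. The sketch below is one hypothetical shape it could take, not the actual implementation: it wraps `process`, times the call, and attaches a metrics entry to the returned dict. The `chunk_metrics` key, the host/PID worker label, and reading chunk metadata from an `events.metadata` attribute are all assumptions.

```python
import functools
import os
import socket
import time


def track_metrics(process):
    """Hypothetical decorator: wrap process(self, events) and attach chunk metrics."""

    @functools.wraps(process)
    def wrapper(self, events):
        start = time.perf_counter()
        result = process(self, events)
        elapsed = time.perf_counter() - start
        # Stored as a one-element list so per-chunk entries can be concatenated
        result["chunk_metrics"] = [
            {
                # Assumed location of file/dataset info; empty if absent
                "metadata": dict(getattr(events, "metadata", {})),
                "nevents": len(events),
                "seconds": elapsed,
                # Stand-in for a worker ID: host name and process ID
                "worker": f"{socket.gethostname()}:{os.getpid()}",
            }
        ]
        return result

    return wrapper


class DemoProcessor:
    """Toy class standing in for a Coffea processor."""

    @track_metrics
    def process(self, events):
        return {"nevents": len(events)}


out = DemoProcessor().process([1, 2, 3])
print(out["chunk_metrics"][0]["nevents"])  # 3
```

Because the metrics are recorded inside `process`, they travel with the result and do not depend on catching the task in the scheduler before it is released.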

## Cleanup

In [None]:
client.close()
cluster.close()
print("Cluster closed")