In [None]:
from pprint import pprint
import json
import pandas as pd
from TraceLens import TreePerfAnalyzer, GPUEventAnalyser

In [None]:
import torch
import torchvision.models as models
from torch.profiler import schedule
from torch.profiler import profile, ProfilerActivity

# Run configuration: a ResNet-18 held on the CUDA device in bfloat16.
device = "cuda"
dtype = torch.bfloat16
model = models.resnet18().to(device=device, dtype=dtype)

# Batch geometry (batch, channels, height, width) and the width of the target.
B, C, H, W = 16, 3, 224, 224
num_classes = 1000

# Random stand-in tensors that are fed to the forward/backward pass below.
dummy_input = torch.randn((B, C, H, W), device=device, dtype=dtype)
dummy_target = torch.randn((B, num_classes), device=device, dtype=dtype)

def train_step():
    """
    Run one forward pass through ``model`` and back-propagate the MSE loss.

    Uses the module-level ``dummy_input`` / ``dummy_target`` tensors.
    No optimizer step and no gradient reset are performed here.
    """
    prediction = model(dummy_input)
    mse = torch.nn.functional.mse_loss(prediction, dummy_target)
    mse.backward()

# Smoke test: run a single forward + backward step once, outside the profiler,
# so that any setup problem shows up before the profiling run below.
train_step()

def warm_up(iters: int = 10):
    """
    Execute ``train_step`` *iters* times, then wait for the GPU to finish.

    Parameters
    ----------
    iters : int
        Number of un-measured training steps to run (default 10).
    """
    for _warm_iter in range(iters):
        train_step()
    # Block until all queued CUDA work has completed.
    torch.cuda.synchronize()


# Profiler schedule parameters; ``sched_active`` is also read by ``trace_handler``.
sched_wait = 10
sched_warmup = 5
sched_active = 3
sched_repeat = 2

sched = schedule(
    wait=sched_wait,
    warmup=sched_warmup,
    active=sched_active,
    repeat=sched_repeat,
)

def trace_handler(p):
    """
    Export the chrome trace of the profiling window that has just finished.

    Registered as the profiler's ``on_trace_ready`` callback.  The output file
    is written to the current working directory and is named
    ``trace_iter<start>_<end>.json`` where ``end`` is ``p.step_num`` at the
    time of the call and ``start`` is ``end - sched_active + 1``.

    Parameters
    ----------
    p : torch.profiler.profile
        The active profiler instance handed in by the callback.
    """
    # NOTE(review): with wait=10, warmup=5, active=3 the first file is
    # trace_iter16_18.json, i.e. the numbers read as 1-based iterations -- confirm.
    end = p.step_num
    start = end - sched_active + 1
    p.export_chrome_trace(f"trace_iter{start}_{end}.json")

# Profile CPU and CUDA activity; ``trace_handler`` exports a trace file each
# time the schedule reports a finished recording window.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=sched,
    on_trace_ready=trace_handler,
    record_shapes=True,
    with_stack=True,
) as p:
    warm_up()
    for _ in range(100):
        train_step()
        p.step()  # signal the end of one iteration to the profiler schedule

In [None]:
# For the nn.Module FLOPs analysis the trace must have been captured with
# with_stack=True, and TraceLens must be told to keep the python-function
# events by passing add_python_func=True.
#
# The file is read relative to the working directory, which is where
# ``trace_handler`` exports it, so the notebook no longer depends on one
# user's home directory.  ``trace_iter16_18.json`` is the file name produced
# for the first recording window of the schedule configured above.
perf_analyzer = TreePerfAnalyzer.from_file("trace_iter16_18.json", add_python_func=True)

In [None]:
"""
Utility helpers for walking a TraceLens event-tree and estimating total FLOPs
for a (sub)graph.  
Key idea:  
* If a CPU op has a registered performance model, we ask TraceLens for
  its predicted FLOPs and stop recursion for that branch.  
* If a CPU op is a *leaf* but **lacks** a model, we record it so the caller
  can decide what to do with the unmodelled work.  
* All other nodes (e.g. python functions, CPU ops with children, etc.)
  are traversed recursively.
"""

import logging
from typing import Dict, Any, List, Tuple
import TraceLens

# --------------------------------------------------------------------------- #
#  Logging configuration                                                      #
# --------------------------------------------------------------------------- #
# You can override this in the application’s entry-point if needed.
# Root logger: timestamped records at DEBUG verbosity
# (switch the level to INFO/ERROR in production).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.DEBUG,
)

# Logger used by the helpers below; the root logger can still be
# reconfigured afterwards by whoever imports this code.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Mapping: op_name → PerfModel subclass (populated by TraceLens at import time).
# The recursive walker below only tests membership (``name in map``) to decide
# whether a CPU op has a performance model it can query.
op_to_perf_model_class_map = TraceLens.PerfModel.op_to_perf_model_class_map


# --------------------------------------------------------------------------- #
#  Helper utilities                                                           #
# --------------------------------------------------------------------------- #
def is_leaf_cpu_op(trace_tree, event) -> bool:
    """
    Tell whether *event* has no GPU-launching CPU op among its direct children.

    Parameters
    ----------
    trace_tree : TraceTree
        Object exposing ``get_children_events(event)``.
    event : dict
        The (host-side) event whose direct children are inspected.

    Returns
    -------
    bool
        ``False`` as soon as one child has ``cat == "cpu_op"`` **and** a
        ``gpu_events`` entry that is not ``None``; ``True`` otherwise
        (including when there are no children at all).

    Notes
    -----
    The category of *event* itself is not checked here; callers are expected
    to pass a CPU op.
    """
    return not any(
        kid.get("cat") == "cpu_op" and kid.get("gpu_events") is not None
        for kid in trace_tree.get_children_events(event)
    )


# --------------------------------------------------------------------------- #
#  Public API                                                                 #
# --------------------------------------------------------------------------- #
def traverse_subtree_and_accumulate_flops(
    trace_tree,
    node: Dict[str, Any],
) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Compute total FLOPs rooted at *node* and gather CPU leaf ops
    that lack a performance model.

    Parameters
    ----------
    trace_tree : perf analyzer
        Despite its name, this argument is forwarded unchanged to
        ``_traverse_subtree_recursive``, which reads ``.tree`` from it and
        calls ``.compute_perf_metrics(node)`` on it.  Pass the perf analyzer
        object (the one that owns the tree), not the tree itself.
    node : dict
        Root event of the subtree to analyse.

    Returns
    -------
    total_flops : float
        Sum of FLOPs for all nodes under *node* that we could model.
        Modelled contributions are computed as ``GFLOPS * 1e9`` and are
        therefore floats; the value is the integer ``0`` when nothing was
        modelled.
    unaccounted_nodes : list[dict]
        Leaf CPU ops with unknown cost (useful for modelling gaps).
    """
    # Thin public wrapper: all the work happens in the recursive walker.
    return _traverse_subtree_recursive(trace_tree, node)


# --------------------------------------------------------------------------- #
#  Internal recursive walker                                                  #
# --------------------------------------------------------------------------- #
def _traverse_subtree_recursive(
    perf_analyzer,
    node: Dict[str, Any],
) -> Tuple[int, List[Dict[str, Any]]]:
    """
    Depth-first walk below *node* that sums the modelled FLOPs.

    Parameters
    ----------
    perf_analyzer
        Object providing ``.tree`` (with ``get_children_events``) and
        ``.compute_perf_metrics(event)``.
    node : dict
        Event currently being visited.

    Returns
    -------
    flops : number
        ``0`` for branches without GPU events or without a model; otherwise
        ``GFLOPS * 1e9`` of each modelled CPU op, summed over the subtree
        (a float as soon as one model contributes).
    unaccounted : list[dict]
        Leaf CPU ops that own GPU events but have no performance model.
    """
    tree = perf_analyzer.tree

    # Guard: a node with no GPU events contributes nothing and is not descended.
    if node.get("gpu_events") is None:
        logger.debug("Skip non-GPU path  UID=%s  name=%s",
                     node.get("UID"), node.get("name"))
        return 0, []

    category = node.get("cat")
    logger.debug("Visit UID=%s  name=%s  cat=%s",
                 node.get("UID"), node.get("name"), category)

    if category == "cpu_op":
        op_name = node["name"]

        # Modelled op: take its FLOPs and do not look at its children.
        if op_name in op_to_perf_model_class_map:
            metrics = perf_analyzer.compute_perf_metrics(node)
            modelled_flops = metrics["GFLOPS"] * 1e9
            logger.debug("Model-hit  op=%s  FLOPs=%d", op_name, modelled_flops)
            return modelled_flops, []

        # Unmodelled leaf: report it back to the caller and stop.
        if is_leaf_cpu_op(tree, node):
            logger.debug("Unmodelled leaf CPU-op  UID=%s  name=%s",
                         node.get("UID"), op_name)
            return 0, [node]

    # Everything else (python functions, CPU ops with GPU-launching children):
    # visit each child in order and merge the partial results.
    child_results = [
        _traverse_subtree_recursive(perf_analyzer, kid)
        for kid in tree.get_children_events(node)
    ]
    subtree_flops = sum(flops for flops, _ in child_results)
    missing_models = [evt for _, pending in child_results for evt in pending]

    logger.debug("Return UID=%s name=%s  cum_FLOPs=%d  unacc=%d",
                 node.get("UID"), node.get("name"),
                 subtree_flops, len(missing_models))
    return subtree_flops, missing_models


In [None]:
# 1. Compute the gpu event metrics for the nn.Module event
target_module_name = 'nn.Module: BasicBlock_2'

# Pick the first event carrying that name; fall back to None so that a missing
# module gives a readable error instead of a bare StopIteration.
nn_module_event = next(
    (e for e in perf_analyzer.tree.events if e.get('name') == target_module_name),
    None,
)
if nn_module_event is None:
    raise LookupError(f"No event named {target_module_name!r} found in the loaded trace")

# UIDs of the GPU events attached to the module event.
list_kernelUIDS = nn_module_event.get('gpu_events')
if list_kernelUIDS is None:
    raise ValueError(f"Event {target_module_name!r} has no 'gpu_events' attached")

# Resolve the UIDs to full event dicts and aggregate their GPU timing.
list_kernels = [perf_analyzer.tree.events_by_uid[uid] for uid in list_kernelUIDS]
gpu_time_metrics = GPUEventAnalyser(list_kernels).compute_metrics()
pprint(gpu_time_metrics)

In [None]:
# 2. Walk the module's subtree: sum the modelled FLOPs and collect the leaf
#    CPU ops that have no performance model.
total_flops, unaccounted_nodes = traverse_subtree_and_accumulate_flops(perf_analyzer, nn_module_event)
print(f"Total FLOPs: {total_flops}")

# Names / UIDs of the unaccounted nodes (CPU ops with no model).
unaccounted_nodes_names = [node['name'] for node in unaccounted_nodes]
print("Unaccounted nodes:")
for node in unaccounted_nodes:
    print(node['UID'], node['name'])

# Achieved throughput over the time the GPU was busy: FLOPs are scaled by
# 1e-12 (-> TFLOPs) and busy_time by 1e-6.
# NOTE(review): the 1e-6 factor implies busy_time is in microseconds -- confirm.
busy_tflops = (total_flops * 1e-12) / (gpu_time_metrics['busy_time'] * 1e-6)
print(f"Busy TFLOPS: {busy_tflops:.2f} TFLOPS")