In [12]:
import torch
import math
import os
import time

In [13]:
# Scaled Dot Product Attention (baseline)

def sdpa_baseline(Q, K, V):
    """
    Q: (B, H, S, D)
    K: (B, H, S, D)
    V: (B, H, S, D)
    """
    B, H, S, D = Q.shape

    # 1) scores = Q @ K^T
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(D)  # (B, H, S, S)

    # 2) softmax
    attn = torch.softmax(scores, dim=-1)

    # 3) out = attn @ V
    out = torch.matmul(attn, V)  # (B, H, S, D)

    return out

In [14]:
# RAPL ENERGY READING FUNCTIONS
def read_energy_uj(path):
    with open(path, 'r') as f:
        return int(f.read().strip())

def find_rapl_sensors():
    """Return list of RAPL energy sensors (energy_uj files)."""
    base = "/sys/class/powercap/intel-rapl"
    sensors = []
    if not os.path.exists(base):
        print("RAPL not available on this system.")
        return sensors

    for root, dirs, files in os.walk(base):
        for name in files:
            if name == "energy_uj":
                sensors.append(os.path.join(root, name))
    return sensors

RAPL_SENSORS = find_rapl_sensors()

def measure_energy_j(func, *args, **kwargs):
    """Run func and measure elapsed time + energy across all RAPL sensors."""
    if not RAPL_SENSORS:
        raise RuntimeError("No RAPL sensors found!")

    # read initial energy
    e_before = [read_energy_uj(s) for s in RAPL_SENSORS]

    # run function
    t0 = time.perf_counter()
    out = func(*args, **kwargs)
    t1 = time.perf_counter()

    # read final energy
    e_after = [read_energy_uj(s) for s in RAPL_SENSORS]

    # convert to Joules
    e_joules = []
    for before, after, sensor in zip(e_before, e_after, RAPL_SENSORS):
        # handle wrap-around
        if after < before:
            diff = (after + (1 << 32)) - before
        else:
            diff = after - before
        e_joules.append(diff / 1e6)  # µJ → J

    exec_time = t1 - t0
    return out, exec_time, e_joules, RAPL_SENSORS

# FLOPs for attention
def attention_flops(B, H, S, D):
    """
    FLOPs breakdown:

    1) QK^T: B * H * S * D * S * 2
    2) softmax: negligible for FLOPs counting (you can include if needed)
    3) (softmax) @ V: B * H * S * S * D * 2

    We return total FLOPs.
    """
    qk_flops = 2 * B * H * S * D * S
    av_flops = 2 * B * H * S * S * D
    return qk_flops + av_flops



def bench_attention(name, fn, Q, K, V, B, H, S, D):
    out, t, energies, sensors = measure_energy_j(fn, Q, K, V)

    total_energy = sum(energies)
    avg_power = total_energy / t
    flops = attention_flops(B, H, S, D)

    gflops_s = flops / t / 1e9
    eff = flops / total_energy / 1e9

    print(f"\n=== {name} ===")
    print(f"Time: {t*1000:.2f} ms")
    print(f"Total Energy: {total_energy:.4f} J")
    print(f"Average Power: {avg_power:.2f} W")
    print(f"Throughput: {gflops_s:.2f} GFLOP/s")
    print(f"Energy Efficiency: {eff:.4f} GFLOP/J")
    return out


In [15]:
# Run Attention Benchmark
B, H, S, D = 1, 1, 256, 64

Q = torch.randn((B, H, S, D), dtype=torch.float32)
K = torch.randn((B, H, S, D), dtype=torch.float32)
V = torch.randn((B, H, S, D), dtype=torch.float32)

# Baseline attention
out_attn = bench_attention("scaled dot-product attention", sdpa_baseline,
                           Q, K, V, B, H, S, D)



=== scaled dot-product attention ===
Time: 14.72 ms
Total Energy: 1.5965 J
Average Power: 108.45 W
Throughput: 1.14 GFLOP/s
Energy Efficiency: 0.0105 GFLOP/J
