In [17]:
# In a real project this code would live in a package called `pomm`, and these
# functions would be available after `import pomm`.
import contextlib, torch, random, time, types, threading
import torch.nn as nn
from functools import wraps
from torch.nn.modules.module import register_module_forward_hook

# Thread-local storage: verification() stores the active record list on it as
# `_THREAD.records`, and _record() appends each sampled record to that list.
_THREAD = threading.local()

def sha256_cpu(x):
    """Return the SHA-256 hex digest of a tensor's raw element bytes.

    The tensor is detached, copied to host memory and made contiguous, so the
    digest depends on the dtype and the element values in row-major order, but
    not on the tensor's strides or device.

    The bytes are taken via a ``uint8`` reinterpretation instead of NumPy.
    This also works for dtypes NumPy cannot represent (e.g. ``torch.bfloat16``)
    and for tensors carrying a lazy conjugate/negation bit. For every dtype the
    NumPy route supported, it produces the same bytes as
    ``x.numpy().tobytes()``.

    Args:
        x: tensor to hash.

    Returns:
        str: 64-character hexadecimal SHA-256 digest.
    """
    import hashlib
    t = x.detach().resolve_conj().resolve_neg().cpu().contiguous()
    # reshape(-1) first: Tensor.view(dtype) with a different element size is
    # not allowed on 0-d tensors.
    raw = t.reshape(-1).view(torch.uint8)
    return hashlib.sha256(raw.numpy().tobytes()).hexdigest()

def _record(op, a, b, out, cfg):
    """Hash one operation's first operand and result, and append a record.

    Called by the matmul wrappers and the ``nn.Linear`` forward hook. Recording
    is sampled: each call is recorded with probability ``cfg.sample_rate``.

    Args:
        op: label stored under ``"op"`` (e.g. ``"mm"`` or ``"linear"``).
        a: first operand; its shape goes in ``"ain"`` and its hash in ``"hin"``.
        b: second operand. NOTE(review): currently NOT hashed, so a record does
            not cover it (e.g. Linear weights) -- TODO decide whether it should be.
        out: result tensor; its shape goes in ``"aout"`` and its hash in ``"hout"``.
        cfg: namespace with a ``sample_rate`` attribute (0.0 = never, 1.0 = always).

    The record is appended to ``_THREAD.records``. Calls on a thread where that
    list was never set up are skipped instead of raising ``AttributeError``.
    """
    records = getattr(_THREAD, "records", None)
    if records is None:
        # The torch patches and the module hook are process-wide, but the
        # record list is thread-local; other threads have nowhere to record.
        return
    # Record with probability sample_rate: `>=` makes 0.0 never record and
    # 1.0 always record (random.random() is in [0, 1)).
    if random.random() >= cfg.sample_rate:
        return
    t0 = time.perf_counter()
    h_in = sha256_cpu(a)
    h_out = sha256_cpu(out)
    h_ms = (time.perf_counter() - t0) * 1000  # time spent hashing, in ms
    records.append(dict(op=op,
                        ain=a.shape, aout=out.shape,
                        hin=h_in, hout=h_out,
                        hash_ms=h_ms))

def _wrap_fn(fn, op_name, cfg):
    """Wrap a binary torch op so every call is passed to ``_record``.

    Args:
        fn: the original function (e.g. ``torch.mm``).
        op_name: label stored in each record's ``"op"`` field.
        cfg: namespace with ``sample_rate``, forwarded to ``_record``.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``) that
        returns exactly what ``fn`` returns.
    """
    @wraps(fn)
    def wrapper(*args, **kw):
        out = fn(*args, **kw)
        # Operands may also be given by keyword: torch.mm / torch.bmm use
        # (input, mat2), torch.matmul uses (input, other).
        a = args[0] if len(args) > 0 else kw.get("input")
        b = args[1] if len(args) > 1 else kw.get("mat2", kw.get("other"))
        if a is not None and b is not None:
            _record(op_name, a, b, out, cfg)
        return out
    return wrapper



@contextlib.contextmanager
def verification(sample_rate=0.2):
    """Record sampled input/output hashes of matmuls and ``nn.Linear`` calls.

    While the context is active:

    * ``torch.mm``, ``torch.matmul`` and ``torch.bmm`` are replaced by wrappers
      that call ``_record`` after each call.
    * A module forward hook registered with ``register_module_forward_hook``
      records every ``nn.Linear`` forward.

    Args:
        sample_rate: probability that any single call is hashed and recorded.

    Yields:
        list[dict]: the record list that ``_record`` appends to.

    Notes:
        * The patches replace attributes on the ``torch`` module, so they are
          process-wide. The record list, however, lives on ``_THREAD`` (a
          ``threading.local``).
        * Only lookups of these three names on the ``torch`` module see the
          wrappers. Code that already holds a reference to the original
          function, or that calls a different entry point, is not intercepted.
        * The previous record list (if any) is restored on exit, so nested
          ``verification()`` blocks do not disconnect the outer list.
    """
    cfg = types.SimpleNamespace(sample_rate=sample_rate)
    prev_records = getattr(_THREAD, "records", None)
    records = _THREAD.records = []

    def _linear_hook(module, inputs, output):
        # This hook fires for EVERY module; only act on Linear. `inputs` holds
        # positional arguments only, so calls made purely by keyword are skipped.
        if isinstance(module, nn.Linear) and inputs:
            _record("linear",
                    inputs[0],          # input activations
                    module.weight.t(),  # weight matrix
                    output,             # output activations
                    cfg)

    patched = []
    hook_handle = None
    try:
        # Setup happens inside `try` so a failure part-way through still
        # restores whatever was already patched.
        # 1) monkey-patch the module-level matmul functions
        for name in ("mm", "matmul", "bmm"):
            orig = getattr(torch, name)
            setattr(torch, name, _wrap_fn(orig, name, cfg))
            patched.append((torch, name, orig))
        # 2) global forward hook for nn.Linear
        hook_handle = register_module_forward_hook(_linear_hook)
        yield records  # user code runs here
    finally:
        for tgt, name, orig in patched:
            setattr(tgt, name, orig)
        if hook_handle is not None:
            hook_handle.remove()
        if prev_records is not None:
            _THREAD.records = prev_records


In [8]:
import torch, torch.nn as nn, torch.nn.functional as F

class Tiny(nn.Module):
    """Toy two-layer MLP: Linear(128 -> 64) -> ReLU -> Linear(64 -> 32)."""

    def __init__(self):
        super(Tiny, self).__init__()
        # l1 is created before l2, which fixes the parameter-initialisation order.
        self.l1 = nn.Linear(128, 64)
        self.l2 = nn.Linear(64, 32)

    def forward(self, x):
        hidden = F.relu(self.l1(x))
        result = self.l2(hidden)
        return result

# Fall back to CPU so this demo also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Tiny().to(device)
x = torch.randn(4, 128, device=device)

with verification(sample_rate=1.0) as recs:
    y = model(x)

# sample_rate=1.0 records every intercepted call: one per nn.Linear (l1, l2).
assert len(recs) == 2, f"expected 2 records, got {len(recs)}"
print("records:", len(recs))
print(recs[0].keys())


records: 2
dict_keys(['op', 'ain', 'aout', 'hin', 'hout', 'hash_ms'])


In [9]:
def timeit(fn, *a, iters=50, warmup=0, **kw):
    """Return the mean wall-clock time of ``fn(*a, **kw)`` in milliseconds.

    Args:
        fn: callable to time.
        *a, **kw: forwarded to ``fn`` on every call.
        iters: number of timed calls; must be >= 1.
        warmup: untimed calls made before timing starts (e.g. to keep one-off
            initialisation costs out of the measurement). Defaults to 0.

    Returns:
        float: milliseconds per call, averaged over ``iters`` calls.

    Raises:
        ValueError: if ``iters`` < 1.

    When CUDA is available, the device is synchronised before and after the
    timed loop so queued asynchronous work is included. On CPU-only setups the
    synchronisation is skipped, since ``torch.cuda.synchronize()`` would raise.
    """
    import time, torch
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn(*a, **kw)
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*a, **kw)
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000 / iters

# Mean latency of the Tiny model's forward pass, first without and then with
# hashing (sample_rate=1.0 hashes every intercepted call).
base_ms = timeit(model, x, iters=100)

with verification(sample_rate=1.0):
    pomm_ms = timeit(model, x, iters=100)

print(
    f"baseline {base_ms:.3f} ms | with-hash {pomm_ms:.3f} ms | "
    f"overhead {(pomm_ms - base_ms) / base_ms * 100:5.1f}%"
)


baseline 0.267 ms | with‑hash 0.583 ms | overhead 118.8%


In [10]:
!pip -q install transformers sentencepiece

## Running the POMM verification on GPT-2

In [15]:
import torch, time, transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
# from pomm import verification   # what we would use if pomm were an installed package

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

model_name = "gpt2"
tokenizer  = AutoTokenizer.from_pretrained(model_name)
model      = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()  # eval mode (e.g. disables dropout)

BATCH   = 16
SEQ_LEN = 32

# Repeat the sentence so the prompt is long enough, then truncate to SEQ_LEN.
prompt_text = "The quick brown fox jumps over the lazy dog. " * 4
tokens  = tokenizer(prompt_text, return_tensors="pt")["input_ids"][0][:SEQ_LEN]
# Slicing silently returns fewer tokens if the prompt is too short; fail loudly
# instead, so that `inputs` really has shape (BATCH, SEQ_LEN).
assert tokens.numel() == SEQ_LEN, (
    f"prompt yields only {tokens.numel()} tokens; need SEQ_LEN={SEQ_LEN}")
inputs  = tokens.unsqueeze(0).repeat(BATCH, 1).to(device)   # (B, L)

Device: cuda


In [16]:
@torch.no_grad()
def timed_forward(model, inputs, iters=20, warmup=0):
    """Return mean milliseconds per forward pass of ``model(inputs).logits``.

    Runs under ``torch.no_grad()``.

    Args:
        model: callable whose result has a ``.logits`` attribute.
        inputs: passed positionally to ``model`` on every call.
        iters: number of timed forward passes; must be >= 1.
        warmup: untimed forward passes run before timing starts. Defaults to 0.

    Returns:
        float: milliseconds per forward pass (per batch).

    Raises:
        ValueError: if ``iters`` < 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Only synchronise when CUDA exists. The device selection falls back to
    # CPU, and torch.cuda.synchronize() raises on CPU-only setups.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        _ = model(inputs).logits
    sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = model(inputs).logits
    sync()
    return (time.perf_counter() - t0) * 1000 / iters   # ms / batch

# Baseline: plain forward passes outside any verification() context.
base_ms = timed_forward(model, inputs, iters=20)
print(f"Baseline  (no POMM): {base_ms:6.2f} ms / batch")

# Same workload with every intercepted matmul / nn.Linear call hashed.
with verification(sample_rate=1.0) as records:
    pomm_ms = timed_forward(model, inputs, iters=20)

print(f"With POMM (sha256): {pomm_ms:6.2f} ms / batch")
print(f"Overhead: {(pomm_ms - base_ms) / base_ms * 100:5.1f}%")

# NOTE(review): in the run shown below, 20 forward passes produced 20 records
# (one per pass), and the example's output width is 50257 (GPT-2's vocabulary
# size, presumably the LM head). The model's other projections were apparently
# not intercepted, presumably because they do not go through nn.Linear or
# torch.mm/matmul/bmm. Verify before treating this as end-to-end coverage.
print(f"\nHash records collected: {len(records)}")
if records:
    print("First record example:")
    print({k: v for k, v in records[0].items() if k not in ('hin','hout')})
else:
    print("No records were collected.")


Baseline  (no POMM):  41.12 ms / batch
With POMM (sha256): 506.18 ms / batch
Overhead: 1130.8%

Hash records collected: 20
First record example:
{'op': 'linear', 'ain': torch.Size([16, 32, 768]), 'aout': torch.Size([16, 32, 50257]), 'hash_ms': 537.5580359998366}
