In [17]:
import socket, struct, threading, time, random, types, contextlib
import torch, numpy as np, torch.nn as nn
from functools import wraps
from torch.nn.modules.module import register_module_forward_hook



HOST, PORT_BASE, GLOBAL_SEED = "127.0.0.1", 11234, 42
random.seed(GLOBAL_SEED); torch.manual_seed(GLOBAL_SEED)

orig_matmul = torch.matmul
orig_tensor_matmul = torch.Tensor.__matmul__
orig_tensor_rmatmul = torch.Tensor.__rmatmul__



def _verifier_server(B_public, m, n, layer_idx, n_rows, n_cols):
    """Bare-bones TCP listener; exits after one blob."""
    def recvall(sock, n):
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise RuntimeError("socket closed")
            buf.extend(chunk)
        return bytes(buf)

    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind((HOST, PORT_BASE + layer_idx))
    srv.listen(1)
    conn, _ = srv.accept()

    hdr = recvall(conn, 8)
    layer, rows, cols = struct.unpack("<HHH2x", hdr)
    print(f"[Verifier] Received header: layer_idx={layer}, n_rows={rows}, n_cols={cols}")

    k = B_public.shape[0]
    bytes_A = rows * k * 4
    bytes_C = rows * cols * 4
    payload = recvall(conn, bytes_A + bytes_C)
    print(f"[Verifier] Received payload: {len(payload)} bytes (A: {bytes_A}, C: {bytes_C})")

    A_rows = np.frombuffer(payload[:bytes_A], dtype=np.float32).reshape(rows, k)
    C_vals = np.frombuffer(payload[bytes_A:], dtype=np.float32).reshape(rows, cols)

    rng = random.Random(GLOBAL_SEED)
    row_idx_v = torch.tensor(rng.sample(range(m), n_rows))
    col_idx_v = torch.tensor(rng.sample(range(n), n_cols))

    print(f"[Verifier] row_idx shape: {row_idx_v.shape}, first few: {row_idx_v[:5]}")
    print(f"[Verifier] col_idx shape: {col_idx_v.shape}, first few: {col_idx_v[:5]}")

    # disable hooks **inside verifier** to avoid recursion
    _THREAD.no_hook = True
    try:
        B_sub      = B_public[:, col_idx_v]                    # k × cols
        recomputed = orig_tensor_matmul(torch.from_numpy(A_rows.copy()), B_sub)
    finally:
        _THREAD.no_hook = False

    diff = torch.abs(recomputed - torch.from_numpy(C_vals.copy()))
    ok = torch.allclose(recomputed, torch.from_numpy(C_vals), atol=1e-3, rtol=1e-3)

    print(f"[Verifier] Max diff: {diff.max().item()}, mean diff: {diff.mean().item()}")
    print(f"[Verifier] layer {layer} passed? {ok}\n")
    conn.close(); srv.close()


def _prover_send(sampled_A, sampled_C, layer_idx):
    n_rows = sampled_A.shape[0]
    n_cols = sampled_C.shape[1]
    print(f"[Prover] Sending: n_rows={n_rows}, n_cols={n_cols}")

    buf_A = sampled_A.cpu().numpy().astype(np.float32).tobytes()
    buf_C = sampled_C.cpu().numpy().astype(np.float32).tobytes()
    hdr = struct.pack("<HHH2x", layer_idx, n_rows, n_cols)

    with socket.create_connection((HOST, PORT_BASE + layer_idx)) as s:
        s.sendall(hdr + buf_A + buf_C)
        print(f"[Prover] Sent {len(hdr)+len(buf_A)+len(buf_C)} total bytes\n")


def audit_protocol(A, B, layer_idx):
    if getattr(_THREAD,"in_audit",False):
        return


    _THREAD.in_audit = True
    try:
        # use the *un-patched* implementation exactly once
        C = orig_matmul(A, B)

        m, _ = A.shape
        n    = B.shape[1]

        rng = random.Random(GLOBAL_SEED)
        row_idx = torch.tensor(rng.sample(range(m), max(1, int(m * 0.001))))
        col_idx = torch.tensor(rng.sample(range(n), max(1, int(n * 0.01))))

        sampled_A = A[row_idx]
        sampled_C = C[row_idx][:, col_idx]

        print(f"[Prover] row_idx shape: {row_idx.shape}, first few: {row_idx[:5]}")
        print(f"[Prover] col_idx shape: {col_idx.shape}, first few: {col_idx[:5]}")
        print(f"[Prover] sampled_A shape: {sampled_A.shape}")
        print(f"[Prover] sampled_C shape: {sampled_C.shape}")

        th = threading.Thread(
            target=_verifier_server,
            args=(B, m, n, layer_idx,
                  sampled_A.shape[0], sampled_C.shape[1]),
            daemon=True)
        th.start()
        time.sleep(0.05)
        _prover_send(sampled_A, sampled_C, layer_idx)
        th.join()
    finally:
        _THREAD.in_audit = False

# -----------------  monkey-patch + public context  ------------------
_THREAD = threading.local()

def _wrap_fn(fn, op_name, cfg):
    @wraps(fn)
    def wrapper(*args, **kw):
        if getattr(_THREAD, "no_hook", False):
            return fn(*args, **kw)          # bypass while flag is set
        out = fn(*args, **kw)
        if random.random() <= cfg.sample_rate:
            cfg.counter += 1
            audit_protocol(args[0], args[1], cfg.counter)
        return out
    return wrapper


@contextlib.contextmanager
def verification(sample_rate=0.1):
    cfg = types.SimpleNamespace(sample_rate=sample_rate, counter=0)
    _THREAD.records = []

    patched = []
    for name in ("mm", "matmul", "bmm"):
        orig = getattr(torch, name)
        setattr(torch, name, _wrap_fn(orig, name, cfg))
        patched.append((torch, name, orig))

     # ------------ patch tensor @-operator methods ------------------
    def _make_tensor_patch(orig_meth):
        @wraps(orig_meth)
        def _tensor_mm(self, other):
            if getattr(_THREAD, "no_hook", False):
                return orig_meth(self, other)
            out = orig_meth(self, other)
            if random.random() <= cfg.sample_rate:
                cfg.counter += 1
                audit_protocol(self, other, cfg.counter)
            return out
        return _tensor_mm

    for meth_name, orig_meth in (("__matmul__", orig_tensor_matmul),
                                 ("__rmatmul__", orig_tensor_rmatmul)):
        setattr(torch.Tensor, meth_name, _make_tensor_patch(orig_meth))
        patched.append((torch.Tensor, meth_name, orig_meth))

    def _linear_hook(module, inputs, output):
        if isinstance(module, nn.Linear) and random.random() <= cfg.sample_rate:
            cfg.counter += 1
            audit_protocol(inputs[0], module.weight.t(), cfg.counter)

    hook_handle = register_module_forward_hook(_linear_hook)

    try:
        yield
    finally:
        for tgt, name, orig in patched:
            setattr(tgt, name, orig)
        hook_handle.remove()


# ----------------------------  demo  --------------------------------
if __name__ == "__main__":
    A = torch.randn(10000, 10000)
    B = torch.randn(10000, 1000)

    with verification(sample_rate=1.0):
        _ = A @ B
        _ = torch.matmul(A,B)
        torch.matmul(torch.randn(10000,10000), torch.randn(10000,1000))

[Prover] row_idx shape: torch.Size([10]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([10]), first few: tensor([432,  32,  30,  95, 223])
[Prover] sampled_A shape: torch.Size([10, 10000])
[Prover] sampled_C shape: torch.Size([10, 10])
[Prover] Sending: n_rows=10, n_cols=10
[Verifier] Received header: layer_idx=1, n_rows=10, n_cols=10
[Prover] Sent 400408 total bytes

[Verifier] Received payload: 400400 bytes (A: 400000, C: 400)
[Verifier] row_idx shape: torch.Size([10]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([10]), first few: tensor([432,  32,  30,  95, 223])
[Verifier] Max diff: 3.0517578125e-05, mean diff: 5.185604095458984e-06
[Verifier] layer 1 passed? True

[Prover] row_idx shape: torch.Size([10]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([10]), first few: tensor([432,  32,  30,  95, 223])
[Prover] sampled_A shape: torch.Size([10, 10000])
[Prover] s

In [26]:
import time, torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

d = 8192
A = torch.randn(d, d, device=device)
B = torch.randn(d, d // 2, device=device)    # non-square keeps sizes realistic

def time_one(fn):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000  # ms

# warm-up
for _ in range(5): (A @ B)

baseline = time_one(lambda: (A @ B))

with verification(sample_rate=1.0):          # audit *every* matmul
    audited = time_one(lambda: (A @ B))

print(f"baseline: {baseline:.2f} ms  |  audited: {audited:.2f} ms  "
      f"|  overhead: {audited - baseline:.2f} ms  ({(audited/baseline-1)*100:.1f} %)")


Exception in thread Thread-70 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=1, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
baseline: 204.62 ms  |  audited: 365.48 ms  |  overhead: 160.87 ms  (78.6 %)


In [30]:
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

d = 8192
A = torch.randn(d, d, device=device)
B = torch.randn(d, d // 2, device=device)

def _sync():
    if device == "cuda":
        torch.cuda.synchronize()

def time_matmuls(iters: int, *, sample_rate: float | None = None) -> float:
    ctx = verification(sample_rate=sample_rate) if sample_rate is not None else contextlib.nullcontext()

    _sync()
    t0 = time.perf_counter()
    with ctx:
        for _ in range(iters):
            _ = A @ B
    _sync()
    return (time.perf_counter() - t0) * 1000.0   # ms


iters = 100
baseline_ms = time_matmuls(iters)                    # no auditing
audited_ms  = time_matmuls(iters, sample_rate=0.10)  # 10 % audits

print(f"{iters} matmuls — baseline: {baseline_ms:.2f} ms | "
      f"audited (10 %): {audited_ms:.2f} ms | "
      f"overhead: {audited_ms - baseline_ms:.2f} ms "
      f"({(audited_ms / baseline_ms - 1) * 100:.1f} %)")


Exception in thread Thread-102 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=1, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])


Exception in thread Thread-103 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=2, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])


Exception in thread Thread-104 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=3, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])


Exception in thread Thread-105 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=4, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])


Exception in thread Thread-106 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Verifier] Received header: layer_idx=5, n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])


Exception in thread Thread-107 (_verifier_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-17-aabdc5875a37>", line 57, in _verifier_server
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


[Prover] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Prover] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
[Prover] sampled_A shape: torch.Size([8, 8192])
[Prover] sampled_C shape: torch.Size([8, 40])
[Prover] Sending: n_rows=8, n_cols=40
[Prover] Sent 263432 total bytes

[Verifier] Received header: layer_idx=6, n_rows=8, n_cols=40
[Verifier] Received payload: 263424 bytes (A: 262144, C: 1280)
[Verifier] row_idx shape: torch.Size([8]), first few: tensor([1824,  409, 4506, 4012, 3657])
[Verifier] col_idx shape: torch.Size([40]), first few: tensor([3456,  260,  244,  767, 1791])
100 matmuls — baseline: 15592.08 ms | audited (10 %): 17767.23 ms | overhead: 2175.15 ms (14.0 %)
