In [1]:
import torch
import threading
import types
import contextlib
from functools import wraps
from torch import nn
from torch.nn import Module
from torch.nn.modules.module import register_module_forward_hook
import random
import time
import socket  # TCP sockets
import struct # pack/unpack binary headers

# Toy example (no streaming)

In [2]:
# seed setup
torch.manual_seed(42)
random.seed(42)
torch.set_default_dtype(torch.float64)

# WORK FROM PROVER

# B is known to both parties
# A is only known to the prover
A = torch.randn(4000,4000) # activations
B = torch.randn(4000,4000) # public weight matrix

# full matmul
C = A@B

print("C.shape: ",C.shape)

n_rows = C.shape[0]
n_cols=C.shape[1]

# amount of rows and columns to sample
n_row_samples = int(n_rows*0.001)
n_col_samples = int(n_cols * 0.01)

# sampling random row and column indices
row_indices=torch.randperm(n_rows)[:n_row_samples]
col_indices = torch.randperm(n_cols)[:n_col_samples]

# sampling the rows of matrices to be sent over
sampled_A = A[row_indices]
sampled_C = C[row_indices][:,col_indices] # VALUES TO BE CHECKED.


print("shape of sampled_A: ",sampled_A.shape)
print("sampled_C.shape: ",sampled_C.shape)

# VERIFIER SIDE RECOMPUTATION
sampled_B = B[:,col_indices]

mat = sampled_A @ sampled_B
print("shape of matrix recomputation: ",mat.shape)

passed = torch.allclose(mat, sampled_C)
max_diff = (mat - sampled_C).abs().max()
print("max_diff: ",max_diff)
print(passed)


C.shape:  torch.Size([4000, 4000])
shape of sampled_A:  torch.Size([4, 4000])
sampled_C.shape:  torch.Size([4, 40])
shape of matrix recomputation:  torch.Size([4, 40])
max_diff:  tensor(4.9738e-14)
True


# With Streaming

In [19]:
HOST = "127.0.0.1"
PORT = 11234

In [20]:
def prepare_packet(sampled_A: torch.Tensor, sampled_C: torch.Tensor) -> bytes:
  """
  1. copy gpu tensors to cpu pinned memory
  2. extract shapes to include a tiny header
  3. return (header_bytes, raw_bytes_A, raw_bytes_C)
  """
  A_cpu = sampled_A.detach().cpu().pin_memory()
  C_cpu = sampled_C.detach().cpu().pin_memory()

  n_rows, k_dim = A_cpu.shape
  _, n_cols = C_cpu.shape

  header = struct.pack("<III",n_rows,n_cols,k_dim)

  raw_A = A_cpu.numpy().tobytes()
  raw_C = C_cpu.numpy().tobytes()

  return header, raw_A, raw_C

A = torch.randn(10, 16, device="cuda")
B = torch.randn(16, 8,  device="cuda")

C = A @ B
# sample a few rows/cols for demo
rows = torch.arange(3)
cols = torch.arange(4)
sampled_A = A[rows]
sampled_C = C[rows][:, cols]

hdr, bufA, bufC = prepare_packet(sampled_A, sampled_C)
print("Header bytes:", len(hdr), "=> shapes", struct.unpack("<III", hdr))
print("bufA:", len(bufA), "bytes;", "bufC:", len(bufC), "bytes")

Header bytes: 12 => shapes (3, 4, 16)
bufA: 384 bytes; bufC: 96 bytes


In [21]:
def send_packet(header: bytes, raw_A: bytes, raw_C: bytes, chunk_size: int = 1048576):
  """
  open tcp connection to (host, port), then
  1. send 12-byte header
  2. stream raw_A in chunk_size-byte slices
  3. stream raw_C in chunk_size-byte slices
  """

  conn = socket.create_connection((HOST,PORT))
  try:
    conn.sendall(header)
    total_A = len(raw_A)
    offset = 0
    while offset < total_A:
      end = offset + chunk_size
      conn.sendall(raw_A[offset:end])
      offset = end

    total_C = len(raw_C)
    offset = 0
    while offset < total_C:
      end = offset + chunk_size
      conn.sendall(raw_C[offset:end])
      offset = end
  finally:
    conn.close()

In [22]:
def new(A,B):
  C = A @ B
  n_rows, n_cols = C.shape

  # amount of rows and columns to sample
  n_row_samples = int(n_rows*0.001)
  n_col_samples = int(n_cols * 0.01)

  # sampling random row and column indices
  row_indices=torch.randperm(n_rows)[:n_row_samples]
  col_indices = torch.randperm(n_cols)[:n_col_samples]

  # sampling the rows of matrices to be sent over
  sampled_A = A[row_indices]
  sampled_C = C[row_indices][:,col_indices] # VALUES TO BE CHECKED.


  # INCLUDE ALL THE STREAMING LOGIC HERE
  t_prep_start = time.perf_counter()
  header, bufA, bufC = prepare_packet(sampled_A, sampled_C)
  prep_ms = (time.perf_counter() - t_prep_start) * 1000
  print(f"prepare_packet: {prep_ms:.2f} ms")

  t_send_start = time.perf_counter()
  send_packet(header, bufA, bufC)
  send_ms = (time.perf_counter() - t_send_start) * 1000
  print(f"send_packet: {send_ms:.2f} ms")
  print(f"total network overhead: {(prep_ms + send_ms):.2f} ms")

In [23]:
A = torch.randn(400,4000) # activations
B = torch.randn(4000,400) # public weight matrix

def time_matmuls(iters: int) -> float:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(iters):
      _ = A @ B
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0   # ms

def time_sampling_matmuls(iters: int) -> float:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(iters):
      new(A,B)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0   # ms

baseline = time_matmuls(10)
print(baseline, "ms")

new_time = time_sampling_matmuls(10)
print(new_time, "ms")

overhead = new_time - baseline
print(f"overhead (%): {overhead/baseline:}%")

241.50785700021515 ms
prepare_packet: 0.13 ms
send_packet: 0.28 ms
total network overhead: 0.41 ms
[server] connection from ('127.0.0.1', 58712)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
[server] closed connection
prepare_packet: 0.11 ms
send_packet: 0.20 ms
total network overhead: 0.31 ms
[server] connection from ('127.0.0.1', 58718)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
[server] closed connection
prepare_packet: 0.11 ms
[server] connection from ('127.0.0.1', 58730)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
send_packet: 0.31 ms
total network overhead: 0.42 ms
[server] closed connection
prepare_packet: 0.13 ms
[server] connection from ('127.0.0.1', 58746)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
send_packet: 0.30 ms
total network overhead: 0.44 ms
[server] closed connection
prepare_packet: 0.14 ms
send_packet: 0.19 ms
total network

In [38]:
import threading, socket, struct

HOST, PORT = "127.0.0.1", 11234

def robust_server():
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind((HOST, PORT))
    srv.listen(1)
    print("[server] listening on", HOST, PORT)

    while True:
        conn, addr = srv.accept()
        print("[server] connection from", addr)
        try:
            # 1) read header (we packed rows, cols, k_dim)
            hdr = conn.recv(12)
            n_rows, n_cols, k_dim = struct.unpack("<III", hdr)
            print(f"[server] header → rows={n_rows}, cols={n_cols}, k_dim={k_dim}")

            # 2) read exactly A_bytes then C_bytes
            # 8 because float64
            A_bytes = n_rows * k_dim * 8
            C_bytes = n_rows * n_cols * 8

            buf = bytearray()
            while len(buf) < A_bytes + C_bytes:
                chunk = conn.recv((A_bytes + C_bytes) - len(buf))
                if not chunk:
                    raise RuntimeError("connection closed early")
                buf.extend(chunk)
            print(f"[server] received payload: {len(buf)} bytes")

            # (we're just discarding it here; a real verifier would reshape & check)

        except Exception as e:
            print("[server] error during handling:", e)
        finally:
            conn.close()
            print("[server] closed connection")

# start it once, in daemon mode
threading.Thread(target=robust_server, daemon=True).start()


Exception in thread Thread-11 (robust_server):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner


In [25]:
_THREAD = threading.local()

_orig_matmul         = torch.matmul
_orig_tensor_matmul  = torch.Tensor.__matmul__
_orig_tensor_rmatmul = torch.Tensor.__rmatmul__

def _make_warpper(orig_function):
  """
  returns a wrapper function around orig_function that:
  1. bypasses hooking if _THREAD.no_hook is True
  2. sets flag to avoid recursive hooks.
  3. calls orig_function to get the real results
  4. invokes new(a,b) streaming logic
  """
  @wraps(orig_function)
  def wrapper(a,b,*args,**kwargs):
    # if already inside a hook, just do the raw operaetion
    if getattr(_THREAD, "no_hook", False):
      return orig_function(a,b,*args,**kwargs)

    # raise flag so nested matmuls aren't hooked
    _THREAD.no_hook = True
    try:
      out = orig_function(a,b,*args,**kwargs)
      new(a,b)

    finally:
      _THREAD.no_hook = False

    return out
  return wrapper

In [26]:
def _linear_forward_hook(module: Module, inputs: tuple, output: torch.Tensor):
  """
  called after every nn.linear.forward.
  1. skip if inside another hook
  2. pull out the inpute activations and weight matrix and call new() to stream the sampled slice
  """
  if getattr(_THREAD,"no_hook",False):
    return

  _THREAD.no_hook = True
  try:
    inp = inputs[0]
    weight = module.weight.t()
    new(inp,weight)
  finally:
    _THREAD.no_hook = False

In [27]:
@contextlib.contextmanager
def streaming_audit():
  torch.matmul = _make_warpper(torch.matmul)
  torch.Tensor.__matmul__ = _make_warpper(torch.Tensor.__matmul__)
  torch.Tensor.__rmatmul__ = _make_warpper(torch.Tensor.__rmatmul__)

  hook_handle = register_module_forward_hook(_linear_forward_hook)

  try:
    yield # returns to user code with hooks active
  finally:
    # unpatch everything
    torch.matmul              = _orig_matmul
    torch.Tensor.__matmul__   = _orig_tensor_matmul
    torch.Tensor.__rmatmul__  = _orig_tensor_rmatmul

    # 4) Remove the forward-hook
    hook_handle.remove()


In [29]:
A = torch.randn(4000, 4000, device="cuda")
B = torch.randn(4000, 4000, device="cuda")

with streaming_audit():
    # Each of these operations will:
    # 1) compute the result normally on GPU
    # 2) call new(A, B) under the hood to stream samples
    _ = A @ B
    _ = torch.matmul(A, B)
    # If you have an nn.Linear layer:
    lin = torch.nn.Linear(4000, 4000).cuda()
    _ = lin(A)

prepare_packet: 0.23 ms
[server] connection from ('127.0.0.1', 56108)
[server] header → rows=4, cols=40, k_dim=4000
[server] received payload: 64640 bytes
send_packet: 0.59 ms
total network overhead: 0.82 ms
[server] closed connection
prepare_packet: 0.26 ms
[server] connection from ('127.0.0.1', 56118)
[server] header → rows=4, cols=40, k_dim=4000
send_packet: 0.61 ms
total network overhead: 0.88 ms
[server] received payload: 64640 bytes
[server] closed connection
prepare_packet: 0.27 ms
[server] connection from ('127.0.0.1', 48842)
[server] header → rows=4, cols=40, k_dim=4000
[server] received payload: 64640 bytes
send_packet: 0.64 ms
total network overhead: 0.91 ms
[server] closed connection


In [41]:
A = torch.randn(4000, 4000, device="cuda")
B = torch.randn(4000, 4000, device="cuda")

with streaming_audit():
    # Each of these operations will:
    # 1) compute the result normally on GPU
    # 2) call new(A, B) under the hood to stream samples
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(10):
      torch.matmul(A,B)
    torch.cuda.synchronize()
    print((time.perf_counter() - t0) * 1000.0)   # ms


prepare_packet: 0.33 ms
[server] connection from ('127.0.0.1', 56474)
[server] header → rows=4, cols=40, k_dim=4000
send_packet: 0.60 ms
total network overhead: 0.94 ms
[server] received payload: 64640 bytes
[server] closed connection
prepare_packet: 0.26 ms
[server] connection from ('127.0.0.1', 56476)
[server] header → rows=4, cols=40, k_dim=4000
[server] received payload: 64640 bytes
send_packet: 0.56 ms
total network overhead: 0.82 ms
[server] closed connection
prepare_packet: 0.30 ms
[server] connection from ('127.0.0.1', 56484)
[server] header → rows=4, cols=40, k_dim=4000
[server] received payload: 64640 bytes
[server] closed connection


ConnectionResetError: [Errno 104] Connection reset by peer

In [36]:
A = torch.randn(400,4000) # activations
B = torch.randn(4000,400) # public weight matrix

def time_matmuls(iters: int) -> float:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(iters):
      _orig_matmul(A,B)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0   # ms

def time_sampling_matmuls(iters: int) -> float:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(iters):
      new(A,B)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0   # ms

baseline = time_matmuls(10)
print(baseline, "ms")

new_time = time_sampling_matmuls(10)
print(new_time, "ms")

overhead = new_time - baseline
print(f"overhead (%): {overhead/baseline:}%")

232.21993000061047 ms
prepare_packet: 0.12 ms
send_packet: 0.26 ms
total network overhead: 0.38 ms
[server] connection from ('127.0.0.1', 49368)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
[server] closed connection
prepare_packet: 0.15 ms
[server] connection from ('127.0.0.1', 49384)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
send_packet: 0.37 ms
total network overhead: 0.52 ms
[server] closed connection
prepare_packet: 0.18 ms
[server] connection from ('127.0.0.1', 49392)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
[server] closed connection
send_packet: 0.35 ms
total network overhead: 0.53 ms
prepare_packet: 0.11 ms
[server] connection from ('127.0.0.1', 49404)
[server] header → rows=0, cols=4, k_dim=4000
[server] received payload: 0 bytes
send_packet: 0.31 ms
total network overhead: 0.42 ms
[server] closed connection
prepare_packet: 0.11 ms
send_packet: 0.21 ms
total network