In [18]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import hashlib
import random
import json
import datetime
import time

In [2]:
def quantize(x: t.Tensor, num_bits: int = 8) -> t.Tensor:
  """
  Quantization using stochastic rounding.
    - Divides the tensor into 2 ** num_bits - 1 bins (2 ** num_bits possible vals)
      and randomly rounds to each val with probability proportional to distance from val
    - Maintains unbiasedness
  """
  x_min = x.min()
  x_max = x.max()

  if x_max == x_min:    # degenerate case, not likely unless size of tensor is 1
    return x.clone()

  bins = 2 ** num_bits - 1
  scale = (x_max - x_min) / bins

  x_scaled = (x - x_min) / scale
  x_floor = t.floor(x_scaled)
  x_rem = x_scaled - x_floor

  rnd = t.rand_like(x_rem)
  x_quantized = (x_floor + (rnd < x_rem).float()) * scale + x_min

  return x_quantized

In [3]:
t_test = t.rand(10) * 2 - 1     # from -1 to 1
print(t.sort(t_test).values)

t_quantized = quantize(t_test, num_bits=4)
print(t.sort(t_quantized).values)

tensor([-0.8523, -0.7502, -0.7030, -0.3144, -0.2169, -0.2010,  0.0906,  0.5037,
         0.5889,  0.6592])
tensor([-0.8523, -0.7515, -0.6508, -0.3485, -0.2477, -0.1469,  0.0546,  0.4577,
         0.6592,  0.6592])


In [4]:
def hash_tensor(x: t.Tensor) -> str:
  """
  Hash (SHA256) for a tensor.
    - Current implementation has a lot of overhead
    - First moving tensor to CPU, then converting to numpy, then hashing
    - Ideas of batching tensors to hash?
  """
  h = hashlib.sha256()
  h.update(x.detach().cpu().numpy().tobytes())
  return h.hexdigest()

In [5]:
print(hash_tensor(t_quantized))

fd40cbee873ec7d4a971476bd4abfd2f3e07bf14ea5d5bdd2e71dbc4ad151a80


In [6]:
class ProofLinear(nn.Module):
  """
  nn.Linear wrapper that logs proofs for the matrix multiplication.
    - hashes input when forward is called, then hashes output
    - also contained hashed weights for verification
  """
  def __init__(self,
               in_features: int,
               out_features: int,
               bias: bool = True,
               is_quantized: bool = True,
               num_bits: int = 8):
    super().__init__()

    self.linear = nn.Linear(in_features, out_features, bias=bias)
    if is_quantized:
      self.weight_quantized = quantize(self.linear.weight.data, num_bits=8)
    else:
      self.weight_quantized = self.linear.weight.data

    self.weight_hash = hash_tensor(self.weight_quantized)
    self.proof_records = []
    self.is_quantized = is_quantized
    self.num_bits = num_bits

  def forward(self, x: t.Tensor) -> t.Tensor:
    input_hash = hash_tensor(x.detach())

    y = self.linear(x)

    if self.is_quantized:
      x_to_hash = quantize(x, num_bits=self.num_bits)
      y_to_hash = quantize(y, num_bits=self.num_bits)
    else:
      x_to_hash, y_to_hash = x, y

    input_hash = hash_tensor(x_to_hash.detach())
    output_hash = hash_tensor(y_to_hash.detach())

    dic = {
      "module": "ProofLinear",
      "input": x.detach().clone(),
      "output": y.detach().clone(),
      "is_quantized": self.is_quantized
    }

    if self.is_quantized:
      dic["weight_quantized"] = self.weight_quantized.clone()
      dic["input_quantized"] = x_to_hash.clone()
      dic["output_quantized"] = y_to_hash.clone()

    dic["weight_hash"] = self.weight_hash
    dic["input_hash"] = input_hash
    dic["output_hash"] = output_hash

    self.proof_records.append(dic)

    return y

In [7]:
def wrap_model_linear(model: nn.Module, is_quantized: bool = True):
  """
  BFS way to replace all linear layers with ProofLinear layers
  - (haven't tested on more complex models, but should work)
  """
  queue = [model]

  while queue:
    parent = queue.pop(0)

    for name, child in list(parent.named_children()):
      if isinstance(child, t.nn.Linear):
        proof_linear = ProofLinear(
          child.in_features,
          child.out_features,
          bias=(child.bias is not None),
          is_quantized=is_quantized
        )

        proof_linear.linear.weight.data.copy_(child.weight.data)  # copying weight data
        if child.bias is not None:
          proof_linear.linear.bias.data.copy_(child.bias.data)

        setattr(parent, name, proof_linear)
      else:
          queue.append(child)

In [8]:
def get_all_proof_records(model: nn.Module):
  """
  Getting all proof records from the model (as each ProofLinear has the proof_records attribute)
  """
  records = []
  for m in model.modules():
    if hasattr(m, "proof_records"):
      records.extend(m.proof_records)

  return records

In [9]:
class TensorEncoder(json.JSONEncoder):
  """
  Helper class to serialize tensors to JSON
  """
  def default(self, obj):
    if isinstance(obj, t.Tensor):
      return obj.cpu().tolist()

    return super().default(obj)

In [10]:
def verify_proof(model: nn.Module, sample_count_pct: int = 0.2, print_records: bool = False) -> bool:
  """
  Gathers all proof records from the model, randomly samples some percent, recomputes hashes
  and verifies them against the stored values
  """
  proof_records = get_all_proof_records(model)
  if not proof_records:
    print("No proof records.")
    return False

  sample_indices = random.sample(
    range(len(proof_records)),
    int(sample_count_pct * len(proof_records))
  )
  all_verified = True

  for idx in sample_indices:
    record = proof_records[idx]

    # recomputing hashes
    att = "_quantized" if record["is_quantized"] else ""
    recomputed_input_hash = hash_tensor(record[f"input{att}"])
    recomputed_weight_hash = hash_tensor(record[f"weight{att}"])
    recomputed_output_hash = hash_tensor(record[f"output{att}"])

    proof_records[idx]["recomputed_weight_hash"] = recomputed_weight_hash
    proof_records[idx]["recomputed_input_hash"] = recomputed_input_hash
    proof_records[idx]["recomputed_output_hash"] = recomputed_output_hash

    proof_records[idx]["verified_weight_hash"] = record["weight_hash"] == recomputed_weight_hash
    proof_records[idx]["verified_input_hash"] = record["input_hash"] == recomputed_input_hash
    proof_records[idx]["verified_output_hash"] = record["output_hash"] == recomputed_output_hash

    if not (proof_records[idx]["verified_weight_hash"] and
      proof_records[idx]["verified_input_hash"]        and
      proof_records[idx]["verified_output_hash"]):

      all_verified = False

  if print_records: # dumps proof records to json file
    name = "NA"
    if hasattr(model, "name"):
      name = model.name

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    with open(f"proof_records_{name}_{timestamp}.json", "w") as f:
      json.dump(proof_records, f, cls=TensorEncoder, indent=4)

  return all_verified

In [11]:
class DummyModel(nn.Module):
  """
  Mock dummy model for testing
  """
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(128, 64)
    self.linear2 = nn.Linear(64, 32)
    self.linear3 = nn.Linear(32, 16)
    self.linear4 = nn.Linear(16, 8)

  def forward(self, x: t.Tensor) -> t.Tensor:
    x = self.linear1(x)
    x = F.relu(x)
    x = self.linear2(x)
    x = F.relu(x)
    x = self.linear3(x)
    x = F.relu(x)
    x = self.linear4(x)

    return x

In [12]:
test_model = DummyModel()
test_model.name = "test"

wrap_model_linear(test_model)

test_input = t.randn(1, 128)
test_output = test_model(test_input)

verification_passed = verify_proof(test_model, sample_count_pct=1, print_records=True)

print(verification_passed)

True


In [33]:
def gpu_time_ms(fn, *args, **kwargs):
    """
    run a CUDA kernel or tensor op and return elapsed time in milliseconds.
    falls back to perf_counter on CPU tensors.
    """
    if args and t.is_tensor(args[0]) and args[0].is_cuda:
        # print("on gpu")
        start, end = t.cuda.Event(enable_timing=True), t.cuda.Event(enable_timing=True)
        t.cuda.synchronize()
        start.record()
        out = fn(*args, **kwargs)
        end.record()
        t.cuda.synchronize()
        return out, start.elapsed_time(end)
    else:
        print("on cpu")
        t0 = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, (time.perf_counter() - t0)*1000

class ProofLinearTimed(nn.Module):
    """
    nn.Linear wrapper that:
    • does the original proof logging (weights / inputs / outputs + hashes)
    • records per‑call timing:
        - matmul_ms : time spent in self.linear(x)
        - hash_ms   : time spent in CPU SHA‑256 hashing
    """
    def __init__(self, in_features, out_features, bias=True,
                 is_quantized=True, num_bits=8):
        super().__init__()
        self.linear       = nn.Linear(in_features, out_features, bias=bias)
        self.is_quantized = is_quantized
        self.num_bits     = num_bits

        self.weight_quantized = quantize(self.linear.weight.data, num_bits) \
                                if is_quantized else self.linear.weight.data
        self.weight_hash      = hash_tensor(self.weight_quantized)
        self.proof_records    = []

    def forward(self, x: t.Tensor) -> t.Tensor:
        #  measure the matrix multiplication
        y, matmul_ms = gpu_time_ms(self.linear, x)

        if self.is_quantized:
            x_to_hash = quantize(x, self.num_bits)
            y_to_hash = quantize(y, self.num_bits)
        else:
            x_to_hash, y_to_hash = x, y

        #  measure hashing time
        t0 = time.perf_counter()
        input_hash  = hash_tensor(x_to_hash.detach())
        output_hash = hash_tensor(y_to_hash.detach())
        hash_ms = (time.perf_counter() - t0) * 1000

        rec = dict(
            module          = "ProofLinearTimed",
            input           = x.detach().clone(),
            output          = y.detach().clone(),
            is_quantized    = self.is_quantized,
            weight_hash     = self.weight_hash,
            input_hash      = input_hash,
            output_hash     = output_hash,
            matmul_ms       = matmul_ms,
            hash_ms         = hash_ms,
        )
        if self.is_quantized:
            rec.update(
                weight_quantized = self.weight_quantized.clone(),
                input_quantized  = x_to_hash.clone(),
                output_quantized = y_to_hash.clone()
            )
        self.proof_records.append(rec)
        return y

def wrap_model_linear_timed(model: nn.Module, is_quantized=True):
    queue = [model]
    while queue:
        parent = queue.pop(0)
        for name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear):
                pl = ProofLinearTimed(child.in_features, child.out_features,
                                      bias=(child.bias is not None),
                                      is_quantized=is_quantized)
                pl.linear.weight.data.copy_(child.weight.data)
                if child.bias is not None:
                    pl.linear.bias.data.copy_(child.bias.data)
                setattr(parent, name, pl)
            else:
                queue.append(child)

def summarize_timing(model: nn.Module):
    records = get_all_proof_records(model)
    if not records:
        print("‑ no timing data recorded ‑")
        return
    total_mm  = sum(r["matmul_ms"] for r in records)
    total_hash= sum(r["hash_ms"]   for r in records)
    print(f"Total matmul time : {total_mm:9.3f} ms")
    print(f"Total hash  time  : {total_hash:9.3f} ms")
    print(f"Hashing overhead  : {100*total_hash/(total_mm+1e-6):6.2f} %")
    print("\nPer‑layer breakdown (ms):")
    for i, r in enumerate(records):
        print(f"  layer {i:02d}  matmul={r['matmul_ms']:7.3f}   hash={r['hash_ms']:7.3f}")

model = DummyModel()
wrap_model_linear_timed(model)
model=model.cuda()

x = t.randn(1, 128, device="cuda")
_ = model(x)

summarize_timing(model)

Total matmul time :     0.473 ms
Total hash  time  :     0.295 ms
Hashing overhead  :  62.33 %

Per‑layer breakdown (ms):
  layer 00  matmul=  0.209   hash=  0.081
  layer 01  matmul=  0.094   hash=  0.073
  layer 02  matmul=  0.086   hash=  0.070
  layer 03  matmul=  0.084   hash=  0.071


In [34]:
BATCHES       =  1        # mini‑batch size
FEATURES_IN   = 4096      # width of first layer
REPEATS       = 50        # number of forward passes

cuda = t.cuda.is_available()
device = "cuda" if cuda else "cpu"
print("Running on", device.upper())

class BiggerDummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(FEATURES_IN, 2048)
        self.l2 = nn.Linear(2048, 1024)
        self.l3 = nn.Linear(1024, 512)
        self.l4 = nn.Linear(512, 256)
    def forward(self,x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        return self.l4(x)

model = BiggerDummy()
wrap_model_linear_timed(model, is_quantized=True)
model = model.to(device)

# warm up
x = t.randn(BATCHES, FEATURES_IN, device=device)
_ = model(x)

for m in model.modules():
    if hasattr(m, "proof_records"):
        m.proof_records.clear()

for _ in range(REPEATS):
    x = t.randn(BATCHES, FEATURES_IN, device=device)
    _ = model(x)

summarize_timing(model)

def bench_one(shape, repeats=100):
    a = t.randn(*shape, device=device)
    b = t.randn(shape[-1], shape[-1]//2, device=device)

    _, gemm_ms = gpu_time_ms(t.mm, a, b)

    # measure hashing
    t0 = time.perf_counter()
    for _ in range(repeats):
        _ = hash_tensor(a)
    hash_ms = (time.perf_counter() - t0)*1000 / repeats

    print(f"Shape {shape}  GEMM {gemm_ms:6.3f} ms   hash {hash_ms:6.3f} ms   "
          f"hash/GEMM ≈ {hash_ms/gemm_ms:4.1f}×")

for shape in [(1,512), (1,4096), (32,4096),(64,8192)]:
    bench_one(shape)

Running on CUDA
Total matmul time :    24.220 ms
Total hash  time  :    23.502 ms
Hashing overhead  :  97.03 %

Per‑layer breakdown (ms):
  layer 00  matmul=  0.202   hash=  0.162
  layer 01  matmul=  0.192   hash=  0.158
  layer 02  matmul=  0.216   hash=  0.250
  layer 03  matmul=  0.188   hash=  0.150
  layer 04  matmul=  0.190   hash=  0.150
  layer 05  matmul=  0.190   hash=  0.150
  layer 06  matmul=  0.190   hash=  0.162
  layer 07  matmul=  0.193   hash=  0.158
  layer 08  matmul=  0.197   hash=  0.156
  layer 09  matmul=  0.214   hash=  0.144
  layer 10  matmul=  0.202   hash=  0.155
  layer 11  matmul=  0.205   hash=  0.175
  layer 12  matmul=  0.202   hash=  0.145
  layer 13  matmul=  0.202   hash=  0.170
  layer 14  matmul=  0.209   hash=  0.143
  layer 15  matmul=  0.199   hash=  0.143
  layer 16  matmul=  0.201   hash=  0.144
  layer 17  matmul=  0.200   hash=  0.164
  layer 18  matmul=  0.202   hash=  0.194
  layer 19  matmul=  0.197   hash=  0.178
  layer 20  matmul=  0

In [32]:
# benchmark GPU matmul vs. CPU‑hash for several batch sizes
import torch, time, pandas as pd

device      = "cuda" if torch.cuda.is_available() else "cpu"
FEATURES_IN = 4096        # width of first layer
REPEATS     = 30          # timed forward passes per batch size
BATCH_LIST  = [1, 2, 4, 8, 16, 32, 64, 128]

def run_once(batch_sz):
    model = BiggerDummy().to(device)
    wrap_model_linear_timed(model, is_quantized=True)

    # 1‑‑ warm‑up (ignored in stats)
    model = BiggerDummy()
    wrap_model_linear_timed(model, is_quantized=True)
    model = model.to(device)
    _ = model(torch.randn(batch_sz, FEATURES_IN, device=device))

    # clear records written by warm‑up
    for m in model.modules():
        if hasattr(m, "proof_records"):
            m.proof_records.clear()

    # 2 timed loop
    for _ in range(REPEATS):
        _ = model(torch.randn(batch_sz, FEATURES_IN, device=device))

    # 3 aggregate timings
    recs = get_all_proof_records(model)
    total_mm   = sum(r["matmul_ms"] for r in recs) / REPEATS
    total_hash = sum(r["hash_ms"]   for r in recs) / REPEATS
    overhead   = 100 * total_hash / (total_mm + 1e-6)
    return total_mm, total_hash, overhead

rows = []
for B in BATCH_LIST:
    mm, h, o = run_once(B)
    rows.append(dict(batch=B, matmul_ms=mm, hash_ms=h, overhead_pct=o))
    print(f"batch={B:>3}   GEMM={mm:6.3f} ms   hash={h:6.3f} ms   overhead={o:5.1f}%")

batch=  1   GEMM= 0.543 ms   hash= 0.537 ms   overhead= 98.8%
batch=  2   GEMM= 0.511 ms   hash= 0.640 ms   overhead=125.4%
batch=  4   GEMM= 0.610 ms   hash= 0.913 ms   overhead=149.7%
batch=  8   GEMM= 0.714 ms   hash= 1.457 ms   overhead=204.1%
batch= 16   GEMM= 0.800 ms   hash= 2.552 ms   overhead=318.9%
batch= 32   GEMM= 0.860 ms   hash= 6.964 ms   overhead=809.9%
batch= 64   GEMM= 1.196 ms   hash=14.517 ms   overhead=1213.4%
batch=128   GEMM= 1.579 ms   hash=20.192 ms   overhead=1279.0%
