In [1]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import hashlib
import random
import json
import datetime

In [2]:
def quantize(x: t.Tensor, num_bits: int = 8) -> t.Tensor:
  """
  Quantization using stochastic rounding.
    - Divides the tensor into 2 ** num_bits - 1 bins (2 ** num_bits possible vals)
      and randomly rounds to each val with probability proportional to distance from val
    - Maintains unbiasedness
  """
  x_min = x.min()
  x_max = x.max()

  if x_max == x_min:    # degenerate case, not likely unless size of tensor is 1
    return x.clone()

  bins = 2 ** num_bits - 1
  scale = (x_max - x_min) / bins

  x_scaled = (x - x_min) / scale
  x_floor = t.floor(x_scaled)
  x_rem = x_scaled - x_floor

  rnd = t.rand_like(x_rem)
  x_quantized = (x_floor + (rnd < x_rem).float()) * scale + x_min

  return x_quantized

In [3]:
t_test = t.rand(10) * 2 - 1     # from -1 to 1
print(t.sort(t_test).values)

t_quantized = quantize(t_test, num_bits=4)
print(t.sort(t_quantized).values)

tensor([-0.7076, -0.6072,  0.0858,  0.1010,  0.2929,  0.3091,  0.4286,  0.6552,
         0.9068,  0.9666])
tensor([-0.7076, -0.5960,  0.0737,  0.0737,  0.2969,  0.2969,  0.5201,  0.6317,
         0.9666,  0.9666])


In [None]:
def hash_tensor(x: t.Tensor) -> str:
  """
  Hash (SHA256) for a tensor.
    - Current implementation has a lot of overhead
    - First moving tensor to CPU, then converting to numpy, then hashing
    - Ideas of batching tensors to hash?
  """
  h = hashlib.sha256()
  h.update(x.detach().cpu().numpy().tobytes())
  return h.hexdigest()

In [None]:
print(hash_tensor(t_quantized))

32003dadb66b1cd55f7a882232c5a30857366896188027647d7878c7f4b53ecb


In [None]:
class ProofLinear(nn.Module):
  """
  nn.Linear wrapper that logs proofs for the matrix multiplication.
    - hashes input when forward is called, then hashes output
    - also contained hashed weights for verification
  """
  def __init__(self,
               in_features: int,
               out_features: int,
               bias: bool = True,
               is_quantized: bool = True,
               num_bits: int = 8):
    super().__init__()

    self.linear = nn.Linear(in_features, out_features, bias=bias)
    if is_quantized:
      self.weight_quantized = quantize(self.linear.weight.data, num_bits=8)
    else:
      self.weight_quantized = self.linear.weight.data

    self.weight_hash = hash_tensor(self.weight_quantized)
    self.proof_records = []
    self.is_quantized = is_quantized
    self.num_bits = num_bits

  def forward(self, x: t.Tensor) -> t.Tensor:
    input_hash = hash_tensor(x.detach())

    y = self.linear(x)

    if self.is_quantized:
      x_to_hash = quantize(x, num_bits=self.num_bits)
      y_to_hash = quantize(y, num_bits=self.num_bits)
    else:
      x_to_hash, y_to_hash = x, y

    input_hash = hash_tensor(x_to_hash.detach())
    output_hash = hash_tensor(y_to_hash.detach())

    dic = {
      "module": "ProofLinear",
      "input": x.detach().clone(),
      "output": y.detach().clone(),
      "is_quantized": self.is_quantized
    }

    if self.is_quantized:
      dic["weight_quantized"] = self.weight_quantized.clone()
      dic["input_quantized"] = x_to_hash.clone()
      dic["output_quantized"] = y_to_hash.clone()

    dic["weight_hash"] = self.weight_hash
    dic["input_hash"] = input_hash
    dic["output_hash"] = output_hash

    self.proof_records.append(dic)

    return y

In [None]:
def wrap_model_linear(model: nn.Module, is_quantized: bool = True):
  """
  BFS way to replace all linear layers with ProofLinear layers
  - (haven't tested on more complex models, but should work)
  """
  queue = [model]

  while queue:
    parent = queue.pop(0)

    for name, child in list(parent.named_children()):
      if isinstance(child, t.nn.Linear):
        proof_linear = ProofLinear(
          child.in_features,
          child.out_features,
          bias=(child.bias is not None),
          is_quantized=is_quantized
        )

        proof_linear.linear.weight.data.copy_(child.weight.data)  # copying weight data
        if child.bias is not None:
          proof_linear.linear.bias.data.copy_(child.bias.data)

        setattr(parent, name, proof_linear)
      else:
          queue.append(child)

In [None]:
def get_all_proof_records(model: nn.Module):
  """
  Getting all proof records from the model (as each ProofLinear has the proof_records attribute)
  """
  records = []
  for m in model.modules():
    if hasattr(m, "proof_records"):
      records.extend(m.proof_records)

  return records

In [None]:
class TensorEncoder(json.JSONEncoder):
  """
  Helper class to serialize tensors to JSON
  """
  def default(self, obj):
    if isinstance(obj, t.Tensor):
      return obj.cpu().tolist()

    return super().default(obj)

In [None]:
def verify_proof(model: nn.Module, sample_count_pct: int = 0.2, print_records: bool = False) -> bool:
  """
  Gathers all proof records from the model, randomly samples some percent, recomputes hashes
  and verifies them against the stored values
  """
  proof_records = get_all_proof_records(model)
  if not proof_records:
    print("No proof records.")
    return False

  sample_indices = random.sample(
    range(len(proof_records)),
    int(sample_count_pct * len(proof_records))
  )
  all_verified = True

  for idx in sample_indices:
    record = proof_records[idx]

    # recomputing hashes
    att = "_quantized" if record["is_quantized"] else ""
    recomputed_input_hash = hash_tensor(record[f"input{att}"])
    recomputed_weight_hash = hash_tensor(record[f"weight{att}"])
    recomputed_output_hash = hash_tensor(record[f"output{att}"])

    proof_records[idx]["recomputed_weight_hash"] = recomputed_weight_hash
    proof_records[idx]["recomputed_input_hash"] = recomputed_input_hash
    proof_records[idx]["recomputed_output_hash"] = recomputed_output_hash

    proof_records[idx]["verified_weight_hash"] = record["weight_hash"] == recomputed_weight_hash
    proof_records[idx]["verified_input_hash"] = record["input_hash"] == recomputed_input_hash
    proof_records[idx]["verified_output_hash"] = record["output_hash"] == recomputed_output_hash

    if not (proof_records[idx]["verified_weight_hash"] and
      proof_records[idx]["verified_input_hash"]        and
      proof_records[idx]["verified_output_hash"]):

      all_verified = False

  if print_records: # dumps proof records to json file
    name = "NA"
    if hasattr(model, "name"):
      name = model.name

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    with open(f"proof_records_{name}_{timestamp}.json", "w") as f:
      json.dump(proof_records, f, cls=TensorEncoder, indent=4)

  return all_verified

In [None]:
class DummyModel(nn.Module):
  """
  Mock dummy model for testing
  """
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(128, 64)
    self.linear2 = nn.Linear(64, 32)
    self.linear3 = nn.Linear(32, 16)
    self.linear4 = nn.Linear(16, 8)

  def forward(self, x: t.Tensor) -> t.Tensor:
    x = self.linear1(x)
    x = F.relu(x)
    x = self.linear2(x)
    x = F.relu(x)
    x = self.linear3(x)
    x = F.relu(x)
    x = self.linear4(x)

    return x

In [None]:
test_model = DummyModel()
test_model.name = "test"

wrap_model_linear(test_model)

test_input = t.randn(1, 128)
test_output = test_model(test_input)

verification_passed = verify_proof(test_model, sample_count_pct=1, print_records=True)

print(verification_passed)

True
