In [None]:
import torch
import random

# seed setup
torch.manual_seed(42)
random.seed(42)

# WORK FROM PROVER 

# Assume that this has already passed the 10% check
# if random.random() < 0.1:


# B is known to both parties
# A is only known to the prover
A = torch.randn(4000,4000) # activations
B = torch.randn(4000,4000) # public weight matrix

# full matmul
C = A@B

print("C.shape: ",C.shape)

n_rows = C.shape[0]
n_cols=C.shape[1]

# amount of rows and columns to sample
n_row_samples = int(n_rows*0.001)
n_col_samples = int(n_cols * 0.01)

# sampling random row and column indices
row_indices=torch.randperm(n_rows)[:n_row_samples]
col_indices = torch.randperm(n_cols)[:n_col_samples]

# MATRICES TO BE SENT OVER
sampled_A = A[row_indices] 
sampled_C = C[row_indices][:,col_indices] # VALUES TO BE CHECKED.


print("shape of sampled_A: ",sampled_A.shape)
print("sampled_C.shape: ",sampled_C.shape) 

# VERIFIER SIDE RECOMPUTATION
sampled_B = B[:,col_indices]

mat = sampled_A @ sampled_B
print(mat.shape)


passed = torch.allclose(mat, sampled_C)
print(passed)

C.shape:  torch.Size([4000, 4000])
shape of sampled_A:  torch.Size([4, 4000])
sampled_C.shape:  torch.Size([4, 40])
torch.Size([4, 40])
True


In [None]:
"""
Toy prover ↔ verifier stream for the long-context audit scheme.

 * One TCP connection per sampled matmul.
 * Payload contains ONLY
       - sampled_A  (fp16 rows)
       - sampled_C  (fp16 scalars)
       - layer_idx  (uint16 for now)
 * Indices are *not* sent; the verifier regenerates them with the same PRNG.

Run this file as-is: first it starts the verifier thread, then the prover
connects, pushes its bytes, and you should see “verification passed? True”.
"""

import socket, struct, threading, time, random, torch
import numpy as np

# ---------------- 0.  GLOBAL CONSTANTS & SEED  --------------------
HOST, PORT       = "127.0.0.1", 1234 # loopback → no firewall friction
GLOBAL_SEED      = 42                      # shared PRNG seed
LAYER_IDX        = 3                       # hard-coded for this toy demo

torch.manual_seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)

# ---------------- 1.  PROVER CALCULATES ITS SAMPLES  -------------
d = 4000
A = torch.randn(d, d)          # activations (private to prover)
B = torch.randn(d, d)          # weights (public / verifier also has)

C = A @ B                      # full matmul

# mask:   0.1 % rows , 1 % cols
rng = random.Random(GLOBAL_SEED)
row_idx = torch.tensor(rng.sample(range(d), int(d*0.001)))
col_idx = torch.tensor(rng.sample(range(d), int(d*0.01)))
# row_idx = torch.randperm(d)[: int(d * 0.001)]      # 4 rows
# col_idx = torch.randperm(d)[: int(d * 0.01)]       # 40 columns

sampled_A = A[row_idx]                             # 4 × 4000
sampled_C = C[row_idx][:, col_idx]                 # 4 × 40

# ---------------- 2.  VERIFIER SERVER THREAD  --------------------
def verifier_server(B_public, d):
    """
    Bare-bones TCP listener.
    Waits for one blob, checks it, prints result, then exits.
    """

    # ---- helper: recv exactly n bytes or die ----
    def recvall(sock, n):
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:              # connection closed too early
                raise RuntimeError("socket closed")
            buf.extend(chunk)
        return bytes(buf)

    # ---- server main ----
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind((HOST, PORT))
    srv.listen(1)
    conn, _ = srv.accept()

    # 2.1 read fixed 8-byte header   (uint16 layer | uint16 n_rows | uint16 n_cols | padding)
    hdr = recvall(conn, 8)
    layer_idx, n_rows, n_cols = struct.unpack("<HHH2x", hdr)

    # 2.2 compute how many bytes to expect for the two tensors
    bytes_A = n_rows * d * 2              # fp16 → 2 B each
    bytes_C = n_rows * n_cols * 2
    payload = recvall(conn, bytes_A + bytes_C)

    # 2.3 slice + reshape
    A_rows = np.frombuffer(payload[:bytes_A], dtype=np.float16).reshape(n_rows, d)
    C_vals = np.frombuffer(payload[bytes_A:], dtype=np.float16).reshape(n_rows, n_cols)

    # 2.4 regenerate PRNG mask (must match prover’s)
    rng = random.Random(GLOBAL_SEED)                  # same seed
    row_idx_v = torch.tensor(rng.sample(range(d), n_rows))
    col_idx_v = torch.tensor(rng.sample(range(d), n_cols))

    # 2.5 recompute C = A @ B  for the masked entries
    B_sub   = B_public[:, col_idx_v]                  # 4000 × 40
    recomputed = torch.from_numpy(A_rows.astype(np.float32)) @ B_sub  # 4 × 40

    ok = torch.allclose(recomputed.half(), torch.from_numpy(C_vals), atol=1e-3, rtol=1e-3)
    print(f"[Verifier] layer {layer_idx}  passed? {ok}")
    conn.close()
    srv.close()

# launch verifier thread
threading.Thread(target=verifier_server, args=(B, d), daemon=True).start()
time.sleep(0.1)          # give the listener a moment

# ---------------- 3.  PROVER SENDS ITS BLOB  ----------------------
def prover_send(sampled_A, sampled_C):
    n_rows, _ = sampled_A.shape
    _,      n_cols   = sampled_C.shape

    # 3.1 flatten to raw bytes (little-endian fp16)
    buf_A = sampled_A.cpu().numpy().astype(np.float16).tobytes()
    buf_C = sampled_C.cpu().numpy().astype(np.float16).tobytes()

    # 3.2 8-byte header: layer | n_rows | n_cols
    hdr = struct.pack("<HHH2x", LAYER_IDX, n_rows, n_cols)

    with socket.create_connection((HOST, PORT)) as s:
        s.sendall(hdr + buf_A + buf_C)    # one shot; OS chunks as needed

prover_send(sampled_A, sampled_C)

# tiny sleep so the verifier thread can print before script exits
time.sleep(0.2)


[Verifier] layer 3  passed? False


In [52]:
print(random.Random(GLOBAL_SEED))

<random.Random object at 0x13cf9ac20>


In [54]:
rng = random.Random(GLOBAL_SEED)
rng.random()

0.6394267984578837

In [55]:
rng = random.Random(GLOBAL_SEED)
rng.random()

0.6394267984578837