# Phase1: Dataset Registration and Verification

In [None]:
!pip -q install "transformers==4.55.0" "tokenizers==0.21.4" "pandas==2.2.2" "eth-hash[pycryptodome]==0.5.2" "requests==2.32.3"


In [None]:
# --- Extra installs for Web3 interaction ---
!pip install web3==6.19.0

**Phase1 Setup**

In [None]:
import os, gzip, shutil, json, time, zipfile
from pathlib import Path
import requests, pandas as pd
from eth_hash.auto import keccak
from transformers import AutoTokenizer

# Config (match your paper)
ZIP_URL = "https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip"
ROW_LIMIT = 200                              # rows per CSV passed to the tokenizer
MODEL_REPO = "Qwen/Qwen2.5-3B-Instruct"     # tokenizer only (no model)
TARGET_FILE = "tokenized_admissions.jsonl"  # change if you want a different proof target

# One working directory per pipeline stage: raw download, canonical CSVs, tokenized JSONL.
RAW_DIR = Path("raw_data")
CANON_DIR = Path("canonical_data")
TOK_DIR = Path("tokenized_data")
for _stage_dir in (RAW_DIR, CANON_DIR, TOK_DIR):
    _stage_dir.mkdir(exist_ok=True)
del _stage_dir


In [None]:
from urllib.parse import urlparse
from web3 import Web3

def infer_dataset_name(zip_url: str) -> str:
    """Derive a dataset name from the last path segment of *zip_url*.

    Only the URL path is inspected, so query strings and fragments are ignored.
    A trailing ``.zip`` extension (in any letter case) is dropped.
    """
    filename = Path(urlparse(zip_url).path).name
    if not filename.lower().endswith(".zip"):
        return filename
    return filename[: -len(".zip")]

# Use explicit name if you set DATASET_NAME; otherwise infer from ZIP_URL filename.
DATASET_NAME = globals().get("DATASET_NAME") or infer_dataset_name(ZIP_URL)

# Phase-1 computeDatasetId(string) == keccak256(abi.encodePacked(name))
dataset_id_bytes = Web3.keccak(text=DATASET_NAME)
# Web3.to_hex always yields exactly one "0x" prefix. HexBytes.hex() includes
# the prefix on some hexbytes versions and omits it on others, so prepending
# "0x" manually produced "0x0x..." in dataset_id_hex.
dataset_id_hex = Web3.to_hex(dataset_id_bytes)

# Optional alias so older cells using the uppercase name still work
DATASET_ID_HEX = dataset_id_hex

print("DATASET_NAME:", DATASET_NAME)
print("DATASET_ID_HEX:", DATASET_ID_HEX)


In [None]:
from eth_hash.auto import keccak

# --- Download ZIP once ---
zip_path = Path("mimic_demo.zip")
if not zip_path.exists():
    # Stream into a temporary file and rename it only after the download has
    # finished. Otherwise an interrupted run would leave a truncated ZIP that
    # every later run treats as already downloaded.
    tmp_zip_path = zip_path.with_name(zip_path.name + ".part")
    with requests.get(ZIP_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(tmp_zip_path, "wb") as f:
            for chunk in r.iter_content(8192):
                f.write(chunk)
    tmp_zip_path.replace(zip_path)

# --- Extract ZIP (skipped when RAW_DIR already has content) ---
if not any(RAW_DIR.iterdir()):
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(RAW_DIR)

# --- Canonicalize helpers (match paper) ---
def canonicalize_csv(in_path, out_path):
    """Write a deterministic, column- and row-sorted copy of a CSV file.

    Every cell is read as a string. Columns are reordered alphabetically, and
    rows are sorted across all columns with missing values last. The result is
    written as UTF-8 with newline line endings and no index column.

    Returns True on success. Returns False (after printing a ``skip:`` line)
    when pandas cannot read *in_path*.
    """
    try:
        frame = pd.read_csv(in_path, dtype=str)
    except Exception as exc:  # best-effort: unreadable files are skipped, not fatal
        print("skip:", in_path.name, exc)
        return False
    ordered_cols = sorted(frame.columns)
    frame = frame[ordered_cols]
    frame = frame.sort_values(by=ordered_cols, na_position="last")
    frame = frame.reset_index(drop=True)
    frame.to_csv(out_path, index=False, encoding="utf-8", lineterminator="\n")
    return True

# --- Decompress .csv.gz then canonicalize all CSVs (except demo_subject_id) ---
# The listing is materialized (sorted) before the loop runs, so files created
# below are not revisited during this pass.
for fp in sorted(RAW_DIR.rglob("*")):
    if not fp.is_file():
        continue
    if "demo_subject_id" in fp.name.lower():
        continue
    if fp.name.endswith(".csv.gz"):
        dec = fp.with_suffix("")  # drop .gz
        if dec.exists():
            # Decompressed on an earlier run. The .csv is part of this listing
            # (it sorts before its .gz), so it has already been canonicalized.
            continue
        # Decompress to a temporary name first so a crash never leaves a
        # truncated .csv that would be mistaken for a finished one.
        tmp_dec = dec.with_name(dec.name + ".part")
        with gzip.open(fp, "rb") as fin, open(tmp_dec, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        tmp_dec.replace(dec)
        fp = dec
    if fp.suffix.lower() == ".csv":
        out = CANON_DIR / fp.name
        if canonicalize_csv(fp, out):
            print("canonicalized:", fp.name)

# --- Tokenize (Qwen tokenizer) ---
tok = AutoTokenizer.from_pretrained(MODEL_REPO)  # uses HF hub
def build_records(csv_path):
    """Tokenize up to ROW_LIMIT rows of *csv_path* into a JSONL file in TOK_DIR.

    Each output line holds the prompt (the row rendered as a Python dict repr,
    so NaN appears as ``nan``), a fixed placeholder response, and the
    tokenizer's ``input_ids``/``attention_mask`` for the (prompt, response)
    pair, truncated to 512 tokens.

    Returns the Path of the written ``tokenized_<stem>.jsonl`` file.
    """
    df = pd.read_csv(csv_path, dtype=str).head(ROW_LIMIT)
    prompts = [
        f"You are a clinical assistant. Given the following record from {csv_path.name}, "
        f"provide a short summary:\n{row.to_dict()}"
        for _, row in df.iterrows()
    ]
    responses = ["Summary: [Your summary here]"] * len(prompts)
    if prompts:
        enc = tok(prompts, text_pair=responses, truncation=True, max_length=512)
    else:
        # Header-only CSV: HF tokenizers can raise on an empty batch. Skip the
        # call but still write an (empty) JSONL so every canonical CSV gets a leaf.
        enc = {"input_ids": [], "attention_mask": []}
    out_path = TOK_DIR / f"tokenized_{csv_path.stem}.jsonl"
    with open(out_path, "w", encoding="utf-8") as f:
        for prompt, response, ids, mask in zip(
            prompts, responses, enc["input_ids"], enc["attention_mask"]
        ):
            rec = {
                "prompt": prompt,
                "response": response,
                "input_ids": list(map(int, ids)),
                "attention_mask": list(map(int, mask)),
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return out_path

TOK_DIR.mkdir(exist_ok=True)
# `csv_path` (not `csv`) so the loop variable does not shadow the stdlib module name.
for csv_path in sorted(CANON_DIR.glob("*.csv")):
    out = build_records(csv_path)
    print("tokenized ->", out.name)

# --- Keccak(file bytes) for each JSONL ---
def file_keccak(p: Path):
    """Return the Keccak-256 digest (raw bytes) of the full contents of file *p*."""
    return keccak(Path(p).read_bytes())

# Leaves are ordered by filename so the Merkle tree is reproducible.
files = sorted((p for p in TOK_DIR.iterdir() if p.suffix == ".jsonl"), key=lambda p: p.name)
leaves = list(map(file_keccak, files))

# === Merkle (SORTED-PAIR rule; duplicate last when odd) ===
def _parent(a: bytes, b: bytes) -> bytes:
    """Hash two sibling nodes in ascending byte order: keccak(lo || hi).

    This mirrors the Solidity rule ``a < b ? hash(a||b) : hash(b||a)``, so
    proofs never need to record left/right direction.
    """
    first, second = (a, b) if a < b else (b, a)
    return keccak(first + second)

def merkle_root(hashes: list[bytes]) -> bytes:
    """Fold leaf hashes into a single sorted-pair Merkle root.

    A level with an odd number of nodes duplicates its last node before
    pairing. The caller's list is copied first and is not modified.

    Raises:
        ValueError: if *hashes* is empty.
    """
    layer = hashes[:]
    if not layer:
        raise ValueError("no leaves")
    while len(layer) > 1:
        if len(layer) % 2:
            layer.append(layer[-1])
        layer = [_parent(layer[j], layer[j + 1]) for j in range(0, len(layer), 2)]
    return layer[0]

def merkle_proof(hashes: list[bytes], idx: int) -> list[str]:
    """Return the sibling path (``0x`` hex strings) from leaf *idx* to the root.

    The tree uses the same rules as ``merkle_root``: sorted-pair hashing via
    ``_parent`` and last-node duplication on odd levels. Because pairs are
    sorted, only siblings are emitted and no direction bits are needed.

    Raises:
        IndexError: if *idx* is outside ``range(len(hashes))``.
    """
    layer = hashes[:]
    if not 0 <= idx < len(layer):
        raise IndexError("target index out of range")
    siblings: list[str] = []
    pos = idx
    while len(layer) > 1:
        if len(layer) % 2:
            layer.append(layer[-1])
        # pos ^ 1 is the other member of the (even, odd) pair containing pos.
        siblings.append("0x" + layer[pos ^ 1].hex())
        layer = [_parent(layer[j], layer[j + 1]) for j in range(0, len(layer), 2)]
        pos //= 2
    return siblings

# --- Root / proof for TARGET_FILE ---
root_bytes = merkle_root(leaves)
root       = "0x" + root_bytes.hex()

# Fail with a readable message (instead of a bare StopIteration) when the
# proof target was not produced by the tokenization step.
tgt_i = next((i for i, p in enumerate(files) if p.name == TARGET_FILE), None)
if tgt_i is None:
    raise FileNotFoundError(
        f"TARGET_FILE {TARGET_FILE!r} not found among tokenized files: {[p.name for p in files]}"
    )
proof      = merkle_proof(leaves, tgt_i)
leaf_hex   = "0x" + leaves[tgt_i].hex()

# --- Write artifacts for the paper & registration ---
with open("merkle_hashes.txt", "w", encoding="utf-8") as f:
    f.write(f"Merkle Root: {root}\n\n")
    for p, h in zip(files, leaves):
        f.write(f"{p.name}: 0x{h.hex()}\n")
    f.write("\nProof for " + TARGET_FILE + ":\n")
    f.write("Leaf Hash: " + leaf_hex + "\n")
    f.write("Proof Array: " + str(proof) + "\n")

# Build metadata JSON for the dapp to upload to IPFS
import platform, time, json
# Keep the DATASET_NAME resolved earlier in the notebook (explicit, or inferred
# from ZIP_URL) and fall back to a hard-coded name only when none exists.
# Overwriting it unconditionally made the dataset_id registered on-chain
# disagree with the DATASET_ID_HEX that the verify-bundle cell publishes.
DATASET_NAME = globals().get("DATASET_NAME") or "MIMIC-IV-Demo-v2.2"
dataset_id_hex = "0x" + keccak(DATASET_NAME.encode("utf-8")).hex()
DATASET_ID_HEX = dataset_id_hex  # keep the uppercase alias in sync

# Use the leaf hashes we already computed
file_hash_map = {p.name: "0x" + h.hex() for p, h in zip(files, leaves)}
total_size = sum(p.stat().st_size for p in files)
files_info = [{
    "filename": p.name,
    "size_bytes": p.stat().st_size,
    "keccak256": file_hash_map[p.name],
} for p in files]

preprocessing_metadata = {
    "dataset": {
        "name": DATASET_NAME,
        "dataset_id": dataset_id_hex,
        "description": "Tokenized clinical dataset prepared for Qwen2.5-3B-Instruct fine-tuning.",
        "merkle_root": root,
        "num_files": len(files_info),
        "total_size_bytes": total_size,
        "files": files_info,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    },
    "tokenizer": { "name": MODEL_REPO },
    "prompt_template": {
        "description": "You are a clinical assistant. Given the following record from {filename}, provide a short summary.",
        "max_length": 512,
        "truncation": True
    },
    "environment": {
        "os": platform.system(),
        "python_version": platform.python_version(),
        "pandas_version": pd.__version__
    }
}

with open("preprocessing_metadata.json", "w", encoding="utf-8") as f:
    json.dump(preprocessing_metadata, f, indent=2)

# Optional: write a viewer-friendly verify bundle (no version yet)
verify_bundle = {
    "datasetId": dataset_id_hex,
    "merkleRoot": root,
    "targetFile": TARGET_FILE,
    "leafHash": leaf_hex,
    "proof": proof
}
with open("verify_bundle.json", "w", encoding="utf-8") as f:
    json.dump(verify_bundle, f, indent=2)

print("\n=== RESULTS ===")
print("Root:", root)
print("Target leaf:", leaf_hex)
print("Proof:", proof)
print("\nWrote: merkle_hashes.txt, preprocessing_metadata.json, verify_bundle.json, tokenized_data/")


In [None]:
# === Create a lightweight verify bundle for the dapp (robust; no chain dependency) ===
import os, json, time, string
from web3 import Web3

def _b32_hex(v: str) -> str:
    """Normalize *v* to a lowercase ``0x``-prefixed 32-byte hex string.

    Strips surrounding whitespace, any number of leading ``0x``/``0X``
    prefixes, and all embedded whitespace. Raises AssertionError unless
    exactly 64 hex digits remain.
    """
    body = (v or "").strip()
    while body[:2].lower() == "0x":  # strip ANY number of 0x
        body = body[2:]
    body = "".join(body.split())     # remove spaces/newlines
    is_b32 = len(body) == 64 and all(ch in string.hexdigits for ch in body)
    assert is_b32, f"bad bytes32: {v!r}"
    return "0x" + body.lower()

# Resolve datasetId: use provided hex or derive from a name
if not globals().get("DATASET_ID_HEX") or str(DATASET_ID_HEX).strip().lower() in ("", "0x", "auto"):
    assert globals().get("DATASET_NAME"), "Set DATASET_NAME or provide DATASET_ID_HEX"
    # Web3.to_hex does not depend on whether HexBytes.hex() includes "0x"
    # (slicing [2:] would drop a real hex byte on hexbytes >= 1.0).
    DATASET_ID_HEX = Web3.to_hex(Web3.keccak(text=DATASET_NAME))
DATASET_ID_HEX = _b32_hex(DATASET_ID_HEX)

# Resolve dataset version WITHOUT touching chain
# - Prefer env/variable DATASET_VER if set, else default to 1 during local build.
DS_VERSION = int(os.getenv("DATASET_VER", str(globals().get("DATASET_VER", 1))))

# Disclosure: reuse the ROW_LIMIT that actually drove tokenization (200 only as
# a fallback) so the disclosed value cannot drift from what was done.
ROW_LIMIT = int(globals().get("ROW_LIMIT", 200))

verify_bundle = {
    "datasetId": DATASET_ID_HEX,   # 0x… bytes32
    "version": DS_VERSION,         # uint256
    "leafHash": leaf_hex,          # 0x… bytes32 (target file’s keccak256)
    "proof": proof,                # ["0x…", ...]
    "merkleRoot": root,            # 0x… bytes32 (optional)
    "targetFile": TARGET_FILE,     # optional
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "methodology": {
        "row_limit": ROW_LIMIT,
        "note": f"Only the first {ROW_LIMIT} rows of each CSV were used during tokenization "
                "to create this dataset for demonstration purposes."
    }
}

with open("verify_bundle.json", "w", encoding="utf-8") as f:
    json.dump(verify_bundle, f, indent=2)

print("Wrote verify_bundle.json with disclosure")
print("datasetId:", DATASET_ID_HEX, "version:", DS_VERSION)


# Phase1: Dataset Registration and Verification [Connect to Smart Contract]

**This step sends the registration transaction. Alternatively, it can be done via the UI.**

In [None]:
from web3 import Web3
import json

# ==== CONFIG ====
# Secrets may be supplied via environment variables so they never need to be
# pasted into (and committed with) the notebook. The placeholders remain as
# fallbacks for in-place editing.
import os
RPC_URL = os.getenv("RPC_URL", "Replace_with_your_URL")
PRIVATE_KEY = os.getenv("PRIVATE_KEY", "Replace_with_you_PK")  # your wallet PK (testnet)
ACCOUNT = Web3.to_checksum_address(Web3().eth.account.from_key(PRIVATE_KEY).address)
REGISTRY_ADDR = Web3.to_checksum_address(os.getenv("REGISTRY_ADDR", "DataRegistryAddress"))  # deployed DatasetRegistry

# Load registry ABI (from your Solidity build)
registry_abi = json.loads("""<add_abi>""")

w3 = Web3(Web3.HTTPProvider(RPC_URL))
assert w3.is_connected(), "Web3 not connected"

registry = w3.eth.contract(address=REGISTRY_ADDR, abi=registry_abi)

# === Pin preprocessing_metadata.json to IPFS ===
import os
PINATA_JWT = os.getenv("PINATA_JWT", "Replace_with_your_JWT")
pinata_url = "https://api.pinata.cloud/pinning/pinFileToIPFS"

# Load and update preprocessing metadata
with open("preprocessing_metadata.json", "r", encoding="utf-8") as f:
    preprocessing_meta = json.load(f)

# Add disclosure of the row limit. Reuse the value that drove tokenization
# (200 only as a fallback) so the disclosure cannot drift from reality.
ROW_LIMIT = int(globals().get("ROW_LIMIT", 200))
preprocessing_meta["methodology"] = {
    "row_limit": ROW_LIMIT,
    "note": f"Only the first {ROW_LIMIT} rows of each CSV were used during tokenization "
            "to create this dataset for demonstration purposes."
}

# Overwrite local file so we pin the updated one
with open("preprocessing_metadata.json", "w", encoding="utf-8") as f:
    json.dump(preprocessing_meta, f, indent=2, ensure_ascii=False)

# Pin to IPFS via Pinata. The multipart payload is named `upload` so it does
# not clobber `files`, the list of tokenized JSONL paths built earlier.
with open("preprocessing_metadata.json", "rb") as f:
    upload = {"file": ("preprocessing_metadata.json", f)}
    headers = {"Authorization": f"Bearer {PINATA_JWT}"}
    resp = requests.post(pinata_url, files=upload, headers=headers, timeout=120)
    resp.raise_for_status()
    metadata_cid = resp.json()["IpfsHash"]

print("Pinned metadata CID:", metadata_cid)

# === Register dataset on-chain ===
dataset_id_bytes32 = Web3.to_bytes(hexstr=dataset_id_hex)
merkle_root_bytes32 = Web3.to_bytes(hexstr=root)

nonce = w3.eth.get_transaction_count(ACCOUNT)
tx = registry.functions.registerDataset(
    dataset_id_bytes32,
    merkle_root_bytes32,
    f"ipfs://{metadata_cid}"
).build_transaction({
    "from": ACCOUNT,
    "nonce": nonce,
    "gas": 300000,
    "gasPrice": w3.to_wei("2", "gwei")
})

signed = w3.eth.account.sign_transaction(tx, PRIVATE_KEY)
tx_hash = w3.eth.send_raw_transaction(signed.rawTransaction)
print("registerDataset tx sent:", tx_hash.hex())
receipt = w3.eth.wait_for_transaction_receipt(tx_hash)
print("Status:", receipt.status)
# Stop on a reverted registration. Otherwise the membership check below runs
# against a dataset that was never stored and its False looks like a proof bug.
assert receipt.status == 1, f"registerDataset reverted (tx {tx_hash.hex()})"

# === Verify file membership ===
leaf_bytes32 = Web3.to_bytes(hexstr=leaf_hex)
proof_bytes = [Web3.to_bytes(hexstr=p) for p in proof]

is_member = registry.functions.verifyFileHash(
    dataset_id_bytes32,
    DS_VERSION,
    leaf_bytes32,
    proof_bytes
).call()

print(f"File {TARGET_FILE} membership verified on-chain:", is_member)


**[Extra Step]: Verification and Validation**

In [None]:
# === REBUILD CHECK: Ensure Merkle root matches on-chain ===
from eth_utils import keccak

def file_keccak(p):
    """Return the Keccak-256 digest of the bytes stored at path *p*."""
    with open(p, "rb") as handle:
        data = handle.read()
    return keccak(data)

# Re-read the tokenized JSONL files from disk, ordered by filename.
jsonl_paths = [p for p in TOK_DIR.iterdir() if p.suffix == ".jsonl"]
files = sorted(jsonl_paths, key=lambda p: p.name)
leaves = [file_keccak(p) for p in files]

# Merkle build (sorted pairs, same as contract)
def merkle_root(leaves):
    """Return the sorted-pair Merkle root of *leaves* (a sequence of digests).

    Odd-sized levels duplicate their last node before pairing, matching the
    Phase-1 builder. The function works on a copy: the previous version
    appended to the caller's list whenever the leaf count was odd.

    Raises:
        ValueError: if *leaves* is empty (an empty input would otherwise
            never reach a single-node level).
    """
    level = list(leaves)
    if not level:
        raise ValueError("no leaves")
    while len(level) > 1:
        if len(level) % 2 == 1:
            level.append(level[-1])
        # sorted() on a two-element slice implements the min||max pairing rule.
        level = [keccak(b"".join(sorted(level[i:i + 2]))) for i in range(0, len(level), 2)]
    return level[0]

local_root = "0x" + merkle_root(leaves).hex()
print("Local root:   ", local_root)
print("On-chain root:", root)
assert local_root.lower() == root.lower(), "Mismatch! Something changed."


In [None]:
# === NEGATIVE TEST: Ensure wrong data fails verification ===
import os
fake_leaf = os.urandom(32)  # 32 random bytes; overwhelmingly unlikely to be a real leaf
fake_proof = proof  # reuse the genuine proof computed for TARGET_FILE

verify_call = registry.functions.verifyFileHash(
    dataset_id_bytes32,
    DS_VERSION,
    fake_leaf,
    fake_proof,
)
is_member_fake = verify_call.call()

print("Expected False, got:", is_member_fake)
assert not is_member_fake, "Verifier incorrectly accepted wrong file!"


In [None]:
# === DATASET ID CROSS-CHECK ===
# The contract's computeDatasetId(name) must equal a local keccak of the UTF-8 name.
calc_local = keccak(DATASET_NAME.encode())
calc_onchain = registry.functions.computeDatasetId(DATASET_NAME).call()
print("On-chain computeDatasetId:", calc_onchain.hex())
print("Local keccak of name:     ", calc_local.hex())
assert calc_onchain == calc_local, "Dataset ID mismatch!"


In [None]:
# === PIN verify_bundle.json for third-party verification ===
# `upload` (not `files`) so the tokenized-file list is not clobbered; the
# timeout keeps a stalled Pinata request from hanging the notebook.
with open("verify_bundle.json", "rb") as f:
    upload = {"file": ("verify_bundle.json", f)}
    headers = {"Authorization": f"Bearer {PINATA_JWT}"}
    resp = requests.post(pinata_url, files=upload, headers=headers, timeout=120)
    resp.raise_for_status()
    verify_cid = resp.json()["IpfsHash"]

print("Verify bundle CID:", verify_cid)


# Phase2: Model Training and Inference

**Install packages**

In [None]:
# If you just started a fresh Colab T4, run this ONCE.
# Do NOT install custom torch/cuda or bitsandbytes here — it causes conflicts.
%pip install -U "transformers>=4.44,<4.47" "accelerate>=0.33,<0.36" "peft>=0.13,<0.14" \
  "datasets>=2.20,<2.22" einops sentencepiece evaluate scikit-learn \
  web3==6.19.0 eth-account==0.10.0 requests


**Config**

In [None]:
#@title ✅ Edit these for your environment
# Phase-2 configuration: chain endpoint and contract addresses, the Phase-1
# dataset this run is bound to, optional IPFS pinning credentials, training
# hyperparameters, and the local workspace layout. The `#@param` / `#@title`
# comments are Colab form annotations.
import os
from pathlib import Path

# --------- Chain RPC + Contracts ----------
RPC_URL        = ""  #@param {type:"string"}
PHASE1_ADDR    = ""                            #@param {type:"string"}
PHASE2_ADDR    = ""                            #@param {type:"string"}
# Sepolia test key ONLY — never paste a key that controls real funds.
PRIVATE_KEY    = ""                  #@param {type:"string"}

# ---------- Phase-1 Dataset identity ----------
# Assumes the dataset registered in Phase 1 is the one used for training here.
DATASET_ID_HEX = ""               #@param {type:"string"}
# 0 = auto (a later cell picks the latest version returned by
# getDatasetVersions); or set an explicit 1-based version number.
DATASET_VER = int(os.getenv("DATASET_VER", "0"))

# ---------- IPFS / Pinning (optional) ----------
# pin_json/pin_file try IPFS_PROXY, then Pinata, then web3.storage; with none
# set they return a dummy CID.
PINATA_JWT     = ""                       #@param {type:"string"}
# web3.storage API token
W3S_TOKEN      = ""  #@param {type:"string"}
IPFS_PROXY     = ""      #@param {type:"string"}

# ---------- Data ----------
ZIP_URL        = "https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip"  #@param {type:"string"}

# ---------- Training ----------
MODEL_REPO     = "Qwen/Qwen2.5-0.5B-Instruct"  #@param {type:"string"}
MAX_STEPS      = 40                            #@param {type:"number"}
BATCH_SIZE     = 2                             #@param {type:"number"}
GRAD_ACCUM     = 4                             #@param {type:"number"}
LR             = 2e-4                          #@param {type:"number"}
SEED           = 42                            #@param {type:"number"}
MAX_LEN        = 384                           #@param {type:"number"}


# NOTE(review): the salt cell re-reads the MASTER_SALT env var itself, so
# this value is informational only.
MASTER_SALT = os.getenv("MASTER_SALT", "")


# ---------- Workspace ----------
# Colab-style absolute path; adjust when running elsewhere.
WORK      = Path("/content/phase2_work")
RAW_DIR   = WORK/"raw"
CANON_DIR = WORK/"canon"
TOK_DIR   = WORK/"tokenized"
ARTIF_DIR = WORK/"artifacts"

for d in [WORK, RAW_DIR, CANON_DIR, TOK_DIR, ARTIF_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print("OK: config and folders ready:", WORK)


**Imports & helpers**

In [None]:
import os, io, json, time, gzip, shutil, zipfile, random, math, typing, tarfile
import requests
import pandas as pd
import numpy as np

from pathlib import Path
from web3 import Web3
from eth_account import Account
from eth_account.signers.local import LocalAccount

import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    Trainer, TrainingArguments, DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType

def keccak_hex(b: bytes) -> str:
    """Return the Keccak-256 digest of *b* as a ``0x``-prefixed hex string.

    Uses ``Web3.to_hex`` instead of slicing ``HexBytes.hex()``. Whether
    ``.hex()`` carries the ``0x`` prefix depends on the installed hexbytes
    version, and dropping two characters from an unprefixed string silently
    loses the first byte of the hash.
    """
    return Web3.to_hex(Web3.keccak(b))

def canonical_json(obj) -> bytes:
    """Serialize *obj* to compact, key-sorted UTF-8 JSON for hashing.

    Keys are sorted, separators carry no whitespace, and non-ASCII characters
    are kept as UTF-8, so equal inputs always produce identical bytes.
    """
    text = json.dumps(
        obj,
        ensure_ascii=False,
        separators=(",", ":"),
        sort_keys=True,
    )
    return text.encode("utf-8")

def pin_json(payload: dict) -> str:
    """Pin *payload* as JSON and return its CID.

    Backends are tried in priority order: IPFS_PROXY, Pinata (PINATA_JWT),
    then web3.storage (W3S_TOKEN). When none is configured, a fixed dummy CID
    is returned so the on-chain flow can still be exercised.
    """
    if IPFS_PROXY:
        endpoint = IPFS_PROXY.rstrip("/") + "/ipfs/pin_json"
        resp = requests.post(endpoint, json=payload, timeout=120)
        resp.raise_for_status()
        return resp.json()["cid"]
    if PINATA_JWT:
        resp = requests.post(
            "https://api.pinata.cloud/pinning/pinJSONToIPFS",
            headers={"Authorization": f"Bearer {PINATA_JWT}"},
            json=payload,
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()["IpfsHash"]
    if W3S_TOKEN:
        body = json.dumps(payload).encode("utf-8")
        resp = requests.post(
            "https://api.web3.storage/upload",
            headers={"Authorization": f"Bearer {W3S_TOKEN}"},
            data=body,
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()["cid"]
    return "bafkrei" + "0" * 44  # dummy so you can test chain flow

def pin_file(path: Path) -> str:
    """Pin the file at *path* and return its CID.

    Uses the same backend priority as ``pin_json`` (IPFS_PROXY, Pinata,
    web3.storage). Without any backend configured, it returns a fixed dummy CID.
    """
    content = path.read_bytes()

    def _multipart():
        return {"file": (path.name, content, "application/octet-stream")}

    if IPFS_PROXY:
        resp = requests.post(
            IPFS_PROXY.rstrip("/") + "/ipfs/pin_file",
            files=_multipart(),
            timeout=300,
        )
        resp.raise_for_status()
        return resp.json()["cid"]
    if PINATA_JWT:
        resp = requests.post(
            "https://api.pinata.cloud/pinning/pinFileToIPFS",
            headers={"Authorization": f"Bearer {PINATA_JWT}"},
            files=_multipart(),
            timeout=300,
        )
        resp.raise_for_status()
        return resp.json()["IpfsHash"]
    if W3S_TOKEN:
        resp = requests.post(
            "https://api.web3.storage/upload",
            headers={"Authorization": f"Bearer {W3S_TOKEN}"},
            data=content,
            timeout=300,
        )
        resp.raise_for_status()
        return resp.json()["cid"]
    return "bafkrei" + "0" * 44

def to_bytes32(hexstr: str) -> bytes:
    """Decode a ``0x``-prefixed, 64-hex-digit string into 32 raw bytes.

    Raises AssertionError if the prefix or total length (66 chars) is wrong,
    and ValueError (from ``bytes.fromhex``) if the digits are not valid hex.
    """
    well_formed = hexstr.startswith("0x") and len(hexstr) == 66
    assert well_formed, f"bad bytes32: {hexstr}"
    digits = hexstr[2:]
    return bytes.fromhex(digits)


**Web3 setup & ABIs**

In [None]:
# Connect
w3 = Web3(Web3.HTTPProvider(RPC_URL))
assert w3.is_connected(), "RPC not reachable (check RPC_URL/Infura)"
acct: LocalAccount = Account.from_key(PRIVATE_KEY if PRIVATE_KEY.startswith("0x") else "0x"+PRIVATE_KEY)
print("Using account:", acct.address)
chain_id = w3.eth.chain_id
print("Chain ID:", chain_id)

# Minimal ABIs limited to functions we call (paste the compiled JSON ABI arrays)
PHASE1_ABI = json.loads("""<add_abi>""")

# The previous line here opened an extra "(" that was never closed, which made
# the entire cell a SyntaxError.
PHASE2_ABI = json.loads("""<add_abi>""")

phase1 = w3.eth.contract(address=Web3.to_checksum_address(PHASE1_ADDR), abi=PHASE1_ABI)
phase2 = w3.eth.contract(address=Web3.to_checksum_address(PHASE2_ADDR), abi=PHASE2_ABI)

def send_tx(fn, *args, gas=None):
    """Build, sign, and send ``fn(*args)`` from ``acct``, then wait for the receipt.

    Args:
        fn: a contract function accessor, e.g. ``phase2.functions.createModel``.
        *args: positional arguments for the contract call.
        gas: explicit gas limit. When None, web3 estimates gas while building
            the transaction (raising if the call would revert) and 20% headroom
            is added.

    Returns:
        The transaction receipt of a successful (status 1) transaction.

    Raises:
        RuntimeError: if the transaction was mined but reverted.
    """
    # EIP-1559 fees. 2 * baseFee + tip is the usual cap; it leaves room for
    # base-fee increases before the transaction stops being includable.
    latest = w3.eth.get_block("latest")
    base = latest.get("baseFeePerGas", w3.to_wei(1, "gwei"))
    max_priority = w3.to_wei(2, "gwei")
    params = {
        "from": acct.address,
        "nonce": w3.eth.get_transaction_count(acct.address),
        "chainId": chain_id,
        "maxPriorityFeePerGas": max_priority,
        "maxFeePerGas": 2 * base + max_priority,
    }
    if gas is not None:
        params["gas"] = gas  # explicit limit: skip estimation entirely
    tx = fn(*args).build_transaction(params)
    if gas is None:
        # build_transaction filled "gas" from an estimate; add headroom.
        tx["gas"] = int(tx["gas"] * 1.2)
    signed = acct.sign_transaction(tx)
    txh = w3.eth.send_raw_transaction(signed.rawTransaction)
    rcpt = w3.eth.wait_for_transaction_receipt(txh)
    if rcpt.status != 1:
        raise RuntimeError(f"transaction {Web3.to_hex(txh)} reverted (status={rcpt.status})")
    return rcpt



**Phase-1: Verify dataset root**

In [None]:
# Resolve the dataset id once and validate it before any contract call. An
# empty or short DATASET_ID_HEX would otherwise surface as an opaque ABI error.
ds_id = Web3.to_bytes(hexstr=DATASET_ID_HEX)
assert len(ds_id) == 32, f"DATASET_ID_HEX must encode exactly 32 bytes, got {DATASET_ID_HEX!r}"
versions = phase1.functions.getDatasetVersions(ds_id).call()

# If DATASET_VER isn't set, or is 0, pick latest
if not int(globals().get("DATASET_VER") or 0):
    assert versions, "No versions registered yet for this datasetId"
    # NOTE(review): tuple field [4] is assumed to hold the version number — confirm against the Phase-1 ABI.
    DATASET_VER = int(versions[-1][4])

# Versions are addressed 1-based: version N is versions[N-1]. The lower bound
# stops a non-positive value from silently indexing from the end of the list.
assert 1 <= DATASET_VER <= len(versions), f"Dataset version {DATASET_VER} not found on chain"
onchain_root = Web3.to_hex(versions[DATASET_VER-1][0])
print(f"On-chain datasetRoot (v{DATASET_VER}):", onchain_root)


**Download + canonicalize a small sample**

In [None]:
import gzip

import zipfile

zip_path = WORK / "mimic_demo.zip"
if ZIP_URL and not zip_path.exists():
    # Download to a temp name and rename on completion, so an interrupted
    # download never masquerades as a finished archive on the next run.
    tmp_zip_path = zip_path.with_name(zip_path.name + ".part")
    with requests.get(ZIP_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(tmp_zip_path, "wb") as f:
            for chunk in r.iter_content(8192):
                f.write(chunk)
    tmp_zip_path.replace(zip_path)

if not any(RAW_DIR.iterdir()):
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(RAW_DIR)

def read_csv_robust(path, nrows=None):
    """Read *path* as an all-string DataFrame, tolerating malformed rows.

    The fast C parser is tried first. If it raises, the Python parser is
    retried with ``on_bad_lines="skip"``. Returns None when both attempts fail.
    """
    attempts = (
        {"engine": "c", "low_memory": False},
        {"engine": "python", "on_bad_lines": "skip"},
    )
    for options in attempts:
        try:
            return pd.read_csv(path, dtype=str, nrows=nrows, **options)
        except Exception:
            continue
    return None

def canonicalize_csv_sample(in_path: Path, out_path: Path, nrows: int = 200):
    """Canonicalize the first *nrows* rows of *in_path* into *out_path*.

    The sample is taken before sorting. Columns are then ordered
    alphabetically and rows sorted across all columns (missing values last),
    and the result is written as UTF-8 with newline endings. Parent
    directories of *out_path* are created as needed.

    Returns False if the file is unreadable or yields no rows, else True.
    """
    sample = read_csv_robust(in_path, nrows=nrows)
    if sample is None or sample.empty:
        return False
    column_order = sorted(sample.columns)
    sample = (
        sample[column_order]
        .sort_values(by=column_order, na_position="last")
        .reset_index(drop=True)
    )
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sample.to_csv(out_path, index=False, encoding="utf-8", lineterminator="\n")
    return True

canon_count = 0

# 1) Decompress each .csv.gz next to itself exactly once. Archives whose .csv
#    already exists are skipped so re-running the cell neither redoes the work
#    nor double-counts the file. Writing through a .part file keeps a crash
#    from leaving a truncated .csv behind.
for gz_path in sorted(RAW_DIR.rglob("*.csv.gz")):
    dec_path = gz_path.with_suffix("")  # strip ".gz"
    if dec_path.exists():
        continue
    tmp_path = dec_path.with_name(dec_path.name + ".part")
    with gzip.open(gz_path, "rb") as fin, open(tmp_path, "wb") as fout:
        shutil.copyfileobj(fin, fout)
    tmp_path.replace(dec_path)

# 2) Canonicalize every CSV (shipped or freshly decompressed) exactly once.
for src in sorted(RAW_DIR.rglob("*.csv")):
    out = CANON_DIR / src.name
    if canonicalize_csv_sample(src, out):
        canon_count += 1

assert canon_count > 0, "No CSV files canonicalized"
print("Canonicalized CSVs:", canon_count)


**Build tiny SFT dataset + tokenizer**

In [None]:
def build_pairs(csv_path: Path, max_rows=120):
    """Turn up to *max_rows* rows of *csv_path* into prompt/response dicts.

    Missing cells become empty strings before each row's dict repr is
    embedded in the prompt. Returns [] when the file is unreadable or empty.
    """
    frame = read_csv_robust(csv_path, nrows=max_rows)
    if frame is None or frame.empty:
        return []
    frame = frame.fillna("")
    header = (
        f"You are a clinical assistant. Given the following record from "
        f"{csv_path.name}, provide a short summary:\n"
    )
    return [
        {"prompt": header + str(row.to_dict()), "response": "Summary: [Your summary here]"}
        for _, row in frame.iterrows()
    ]

# Gather prompt/response pairs from every canonical CSV, in filename order.
pairs = [
    pair
    for csv in sorted(CANON_DIR.glob("*.csv"))
    for pair in build_pairs(csv, max_rows=120)
]
assert len(pairs) > 0, "No training pairs built"
len(pairs)


In [None]:
# Tokenizer FIRST (so we don't use it before it's created)
tok = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=True)
if tok.pad_token is None:
    # The LM collator pads batches, so borrow EOS when no pad token is defined.
    tok.pad_token = tok.eos_token

def format_sample(ex):
    """Join an example's prompt and response with a newline into ``text``."""
    return {"text": "\n".join((ex["prompt"], ex["response"]))}

from datasets import Dataset
ds = Dataset.from_list(list(map(format_sample, pairs)))

def tok_fn(batch):
    """Tokenize a batch of ``text`` values, truncating each to MAX_LEN tokens."""
    return tok(batch["text"], max_length=MAX_LEN, truncation=True)

ds_tok = ds.map(tok_fn, batched=True, remove_columns=["text"])
ds_tok = ds_tok.shuffle(seed=SEED)
print(ds_tok)


**Model + LoRA (no bitsandbytes), train**

In [None]:
# Seed every RNG BEFORE the model is built. get_peft_model initializes the
# LoRA weights randomly, so seeding afterwards left them unseeded and made the
# run non-reproducible.
import random, numpy as np, torch
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

use_cuda = torch.cuda.is_available()
bf16_ok = use_cuda and torch.cuda.is_bf16_supported()
# bf16 where supported, fp16 on other GPUs (e.g. T4), fp32 on CPU.
if bf16_ok:
    dtype = torch.bfloat16
elif use_cuda:
    dtype = torch.float16
else:
    dtype = torch.float32

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO, torch_dtype=dtype, device_map=None
).to("cuda" if use_cuda else "cpu")

from peft import LoraConfig, get_peft_model, TaskType
peft_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj"]
)

# IMPORTANT when using gradient checkpointing
base_model.config.use_cache = False
model = get_peft_model(base_model, peft_cfg)

# Make inputs require grad for checkpointing
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

# Enable gradient checkpointing (new HF prefers use_reentrant=False)
try:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
except TypeError:
    # for slightly older transformers
    model.gradient_checkpointing_enable()

# Opt-in LoRA params to train
model.print_trainable_parameters()

collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

from transformers import TrainingArguments, Trainer
train_args = TrainingArguments(
    output_dir=str(WORK/"out"),
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    num_train_epochs=1,
    max_steps=MAX_STEPS,
    logging_steps=10,
    save_steps=MAX_STEPS,
    save_total_limit=1,
    remove_unused_columns=False,
    bf16=bf16_ok,
    fp16=use_cuda and not bf16_ok,   # mixed precision requires a GPU
    optim="adamw_torch",
    report_to="none",
    # Trainer re-seeds from args.seed at init; without this it uses its
    # default (42) regardless of SEED.
    seed=SEED,
    gradient_checkpointing=True,   # let Trainer cooperate with GC
    # Trainer re-enables checkpointing with these kwargs when training starts,
    # overriding the call above, so pass use_reentrant=False here as well.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = Trainer(
    model=model,
    args=train_args,
    tokenizer=tok,  # FutureWarning is fine; still works in 4.x
    train_dataset=ds_tok,
    data_collator=collator
)

train_res = trainer.train()
print("Train done; steps:", train_res.global_step)


In [None]:
import matplotlib.pyplot as plt

# Extract logs from Trainer state
logs = trainer.state.log_history

# Keep only the periodic training-loss records (entries carrying "loss" and "step").
loss_records = [entry for entry in logs if "loss" in entry and "step" in entry]
steps = [entry["step"] for entry in loss_records]
losses = [entry["loss"] for entry in loss_records]

# Plot
plt.figure(figsize=(6, 4))
plt.plot(steps, losses, marker="o", label="Training Loss")
for axis_label, setter in (("Step", plt.xlabel), ("Loss", plt.ylabel)):
    setter(axis_label)
plt.title("Training Loss over Steps")
plt.grid(True)
plt.legend()
plt.tight_layout()

# Save plot to file
plt.savefig("training_loss_plot.png", dpi=300)

print("Plot saved as training_loss_plot.png")


**Save adapter, bundle artifacts, hashes, pin**

In [None]:
import tarfile

lora_dir = ARTIF_DIR/"lora"
lora_dir.mkdir(exist_ok=True, parents=True)
model.save_pretrained(lora_dir)
tok.save_pretrained(lora_dir)

# NOTE(review): tar/gzip headers embed modification times, so weights_hash will
# differ between otherwise identical runs. Normalize mtimes if byte-level
# reproducibility of the archive matters.
weights_tar = ARTIF_DIR/"lora_adapter.tar.gz"
with tarfile.open(weights_tar, "w:gz") as tar:
    tar.add(lora_dir, arcname="lora_adapter")

training_config = {
    "datasetId": DATASET_ID_HEX,
    "datasetVersion": DATASET_VER,
    "datasetRoot": onchain_root,
    "modelRepo": MODEL_REPO,
    "lora": {"r": peft_cfg.r, "alpha": peft_cfg.lora_alpha, "dropout": peft_cfg.lora_dropout},
    "train": {"max_steps": MAX_STEPS, "batch": BATCH_SIZE, "grad_accum": GRAD_ACCUM, "lr": LR, "seed": SEED}
}
cfg_path = ARTIF_DIR/"training_config.json"
cfg_path.write_text(json.dumps(training_config, indent=2), encoding="utf-8")
# The hash covers the canonical JSON encoding, not the pretty-printed file bytes.
config_hash = keccak_hex(canonical_json(training_config))

# The LAST log_history entry is Trainer's end-of-training summary (train_loss,
# train_runtime, ...), which has no "loss" key, so reading it always produced
# 0.0. Use the most recent per-step loss instead.
step_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
metrics = {
    "final_loss": float(step_losses[-1]) if step_losses else 0.0,
    "steps": int(train_res.global_step),
    "seed": SEED
}
metrics_path = ARTIF_DIR/"metrics.json"
metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
metrics_hash = keccak_hex(canonical_json(metrics))

weights_hash = keccak_hex(weights_tar.read_bytes())
print("configHash:", config_hash)
print("metricsHash:", metrics_hash)
print("weightsHash:", weights_hash)

# Pack artifacts for pinning
artifacts_tar = ARTIF_DIR/"artifacts.tar.gz"
with tarfile.open(artifacts_tar, "w:gz") as tar:
    tar.add(weights_tar, arcname="lora_adapter.tar.gz")
    tar.add(cfg_path, arcname="training_config.json")
    tar.add(metrics_path, arcname="metrics.json")

artifacts_cid = pin_file(artifacts_tar)
print("artifacts CID:", artifacts_cid)


**On-chain: create model → start run → finalize**

In [None]:
# Option 2 (clean): create NEW model (auto-generated modelId) + alsoAnchor = true
# Uses existing names: ds_id, DATASET_VER, config_hash, phase2, send_tx, acct, weights_hash, metrics_hash, artifacts_cid

# 1) Prepare args
zero32      = b"\x00" * 32
init_cfg_b  = to_bytes32(config_hash)                          # bytes32 from your training_config hash
code_hash_b = globals().get("code_hash_bytes32", zero32)       # optional
arch_hash_b = globals().get("arch_hash_bytes32", zero32)       # optional
artifacts_uri = f"ipfs://{artifacts_cid}" if "artifacts_cid" in globals() else ""
model_uri   = globals().get("model_uri_string", artifacts_uri)

# 2) Create model with auto-generated modelId and anchor = true
rcpt = send_tx(
    phase2.functions.createModel,
    zero32,            # modelId=0 => contract autogenerates
    ds_id,             # bytes32 (built earlier from DATASET_ID_HEX)
    int(DATASET_VER),  # uint256
    code_hash_b,       # bytes32
    arch_hash_b,       # bytes32
    model_uri,         # string
    init_cfg_b,        # bytes32 (configHash)
    True               # alsoAnchor
)
print("createModel tx:", rcpt.transactionHash.hex())

# 3) Extract the NEW modelId from ModelCreated event
try:
    evs = phase2.events.ModelCreated().process_receipt(rcpt)
    assert evs, "ModelCreated event not found in receipt"
    new_model_id = evs[0]["args"]["modelId"]  # bytes32
except Exception:
    # Fallback: decode via topic signature (no ABI mismatch warnings)
    MODEL_CREATED_SIG = Web3.keccak(
        text="ModelCreated(bytes32,address,bytes32,uint256,bytes32,bytes32,bytes32,string)"
    ).hex()
    new_model_id = None
    for lg in rcpt.logs:
        if lg["address"].lower() == phase2.address.lower() and lg["topics"]:
            if lg["topics"][0].hex().lower() == MODEL_CREATED_SIG.lower():
                new_model_id = Web3.to_bytes(hexstr=lg["topics"][1].hex())
                break
    assert new_model_id is not None, "Could not extract new modelId from logs"

# Web3.to_hex avoids a doubled "0x" when .hex() already includes the prefix.
print("New modelId:", Web3.to_hex(new_model_id))

# 4) Wire the new id for later cells
model_id_bytes = new_model_id
print("model_id_bytes set.")

# 5) Start + finalize a run under THIS new model
run_receipt = send_tx(phase2.functions.startTrainingRun, model_id_bytes, init_cfg_b)
ev = phase2.events.TrainingStarted().process_receipt(run_receipt)
# Without the event there is no runId to finalize; fail here instead of
# passing None into finalizeTrainingRun.
assert ev, "TrainingStarted event not found in receipt"
run_id = int(ev[0]["args"]["runId"])
print("runId:", run_id)

final_receipt = send_tx(
    phase2.functions.finalizeTrainingRun,
    run_id,
    Web3.to_bytes(hexstr=weights_hash),   # NOTE: snake_case variable names from earlier cell
    Web3.to_bytes(hexstr=metrics_hash),
    artifacts_uri
)
print("finalized run tx:", final_receipt.transactionHash.hex())
print("✅ New model created, anchored, run started+finalized")


**Deterministic salts [Still being tested]**

In [None]:
# --- Deterministic salts (robust) ---
import os, hmac, hashlib, secrets, base64, string
HEX = set(string.hexdigits)

def _load_master_salt():
    """Load the 32-byte HMAC master key used by derive_salt.

    Reads the MASTER_SALT environment variable, accepting either 64 hex digits
    (optionally "0x"-prefixed) or base64 that decodes to exactly 32 bytes.
    If the variable is unset or empty, a fresh random key is generated and
    printed once so it can be saved. Salts derived from it are reproducible
    only if that printed key is supplied again later.

    Returns:
        bytes: the 32-byte master key.

    Raises:
        ValueError: MASTER_SALT is set but is neither 32-byte hex nor 32-byte base64.
    """
    s = os.getenv("MASTER_SALT")
    if not s:
        key = secrets.token_bytes(32)
        # Print ONLY if we just generated it (so you can copy/save it securely).
        print("⚠️ Save this MASTER_SALT (hex) somewhere safe:", key.hex())
        return key
    # os.getenv always returns str, so only textual encodings need handling.
    t = s.strip()
    if t.lower().startswith("0x"):
        t = t[2:]
    if len(t) == 64 and all(c in HEX for c in t):
        return bytes.fromhex(t)
    try:
        key = base64.b64decode(t, validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        key = b""
    if len(key) != 32:
        # Raised outside any except block, so the traceback carries no unrelated context.
        raise ValueError("MASTER_SALT must be 32 bytes (hex '0x..' or base64)")
    return key

MASTER_SALT = _load_master_salt()

def derive_salt(master: bytes, model_id_b32: bytes, idx: int, tag: bytes, nbytes: int = 16) -> bytes:
    """
    Deterministic per-inference salt.

    Computes HMAC_SHA256(master, model_id_b32 || uint64_be(idx) || tag) and keeps
    the first ``nbytes`` bytes.
    - master: 32-byte secret key
    - model_id_b32: exactly 32 bytes
    - idx: non-negative position of the inference in its batch
    - tag: 1..16 bytes for domain separation (e.g. b'in' vs b'out')
    - nbytes: 16 (default, 128-bit) or 32

    Raises ValueError when any argument violates the constraints above.
    """
    if not isinstance(master, (bytes, bytearray)) or len(master) != 32:
        raise ValueError("master must be 32 bytes")
    if not isinstance(model_id_b32, (bytes, bytearray)) or len(model_id_b32) != 32:
        raise ValueError("model_id_b32 must be 32 bytes")
    if not isinstance(idx, int) or idx < 0:
        raise ValueError("idx must be a non-negative int")
    if not isinstance(tag, (bytes, bytearray)) or not 1 <= len(tag) <= 16:
        raise ValueError("tag must be 1..16 bytes")
    if nbytes != 16 and nbytes != 32:
        raise ValueError("nbytes should be 16 or 32")
    payload = b"".join((model_id_b32, idx.to_bytes(8, "big"), bytes(tag)))
    digest = hmac.new(master, payload, hashlib.sha256).digest()
    return digest[:nbytes]


**Inference + Merkle batch**

In [None]:
# --- helpers for ΔNLL ---
import torch, numpy as np, json
# Reuse a SEED defined by an earlier cell if present; otherwise default to 42.
SEED = globals().get("SEED", 42)

# pick the tokenizer you actually have in scope: prefer a global named `tokenizer`,
# else fall back to `tok` (NameError if neither was defined by an earlier cell).
_tok = tokenizer if "tokenizer" in globals() else tok

def norm_ipfs(value: str) -> str:
    """Return a clean ipfs://… URI no matter what you pass in (CID or ipfs://CID); "" for empty/None."""
    s = normalize_cid(value)
    return f"ipfs://{s}" if s else ""

def normalize_cid(value: str) -> str:
    """Return the bare CID: strip whitespace, surrounding quotes, then an ``ipfs://`` and an ``ipfs/`` prefix.

    Mirrors the normalize_cid of the later reveal cell, so both cells reduce the same
    input to the same CID (e.g. "ipfs/<cid>" no longer becomes "ipfs://ipfs/<cid>").
    """
    s = (value or "").strip().strip('"').strip("'")
    return s.removeprefix("ipfs://").removeprefix("ipfs/")


def generate(text, max_new=128, *, include_prompt=False):
    """Greedy-decode up to ``max_new`` new tokens for ``text`` with the global ``model``/``_tok``.

    Returns only the newly generated text by default. ``model.generate`` returns
    prompt + continuation here (the sample base_output in the plotting cell starts
    with the prompt). Decoding the whole sequence made the prompt part of every
    "output", so the token-drop ΔNLL target contained the very fields being
    ablated and mostly measured copying. Pass ``include_prompt=True`` to get the
    previous prompt+completion string.
    """
    inputs = _tok(text, return_tensors="pt").to(model.device)
    # Disable dropout so greedy decoding is reproducible (the Phase-3 generate does the same).
    model.eval()
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new, do_sample=False)
    seq = out[0] if include_prompt else out[0][inputs["input_ids"].shape[1]:]
    return _tok.decode(seq, skip_special_tokens=True)

def total_nll_on_target(model, tok, prompt: str, target: str) -> float:
    """Summed negative log-likelihood of ``target``'s tokens given ``prompt``.

    Prompt and target are tokenized separately and concatenated. Prompt positions
    get label -100, so, assuming a Hugging Face causal-LM head, ``out.loss`` is the
    mean NLL over the target positions that are actually predicted. That mean is
    scaled back to a sum.

    Returns 0.0 when no target token is predicted (empty target, or an empty
    prompt with a single target token). In that case the mean loss would be NaN.
    """
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        enc_p = tok(prompt, return_tensors="pt").to(device)
        enc_t = tok(target, return_tensors="pt", add_special_tokens=False).to(device)
        input_ids = torch.cat([enc_p["input_ids"], enc_t["input_ids"]], dim=1)
        attn_mask = torch.cat([enc_p["attention_mask"], enc_t["attention_mask"]], dim=1)
        labels = input_ids.clone()
        labels[:, :enc_p["input_ids"].shape[1]] = -100  # ignore prompt
        # Causal-LM loss shifts labels left by one (position 0 is never predicted),
        # so count predicted target tokens the same way the loss does.
        n_pred = int((labels[:, 1:] != -100).sum().item())
        if n_pred == 0:
            return 0.0
        out = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
        return float(out.loss) * n_pred

def token_drop_report(record_text: str, keys: list[str], base_out: str) -> dict:
    """Token-drop attribution of ``base_out`` to substrings of the prompt.

    For each key, every occurrence is deleted from ``record_text`` and the summed NLL
    of ``base_out`` (via total_nll_on_target with the global model/_tok) is recomputed.
    ``delta_nll = ablated - baseline``: positive means ``base_out`` became less likely
    without that key.

    Returns {"base_output": base_out[:400], "attributions": [{"feature", "delta_nll"}, ...]}.
    """
    baseline = total_nll_on_target(model, _tok, record_text, base_out)

    def _delta(feature: str) -> float:
        without_feature = record_text.replace(feature, "")
        return float(total_nll_on_target(model, _tok, without_feature, base_out) - baseline)

    attributions = [{"feature": feature, "delta_nll": _delta(feature)} for feature in keys]
    return {"base_output": base_out[:400], "attributions": attributions}

rng = np.random.default_rng(SEED + 1)
events = []
for i in range(3):  # a small batch
    ex = pairs[rng.integers(0, len(pairs))]
    prompt = ex["prompt"]
    out = generate(prompt)

    # --- deterministic salts (persistable via MASTER_SALT + (model_id, idx, tag)) ---
    # NOTE(review): the salt depends only on (model_id, idx, tag), so position i gets the
    # same salt in every batch committed for this model; equal prompts at the same index
    # of two batches therefore produce equal inputHash values.
    input_salt  = derive_salt(MASTER_SALT, model_id_bytes, i, b"in")
    output_salt = derive_salt(MASTER_SALT, model_id_bytes, i, b"out")

    # Show the prompt and output for inspection / logging
    print(f"\n=== Inference {i} ===")
    print("Prompt:\n", prompt)
    print("Output:\n", out)

    # Append (never overwrite) to a local log so plaintexts + salts are available for later reveals
    with open("inference_log.jsonl", "a", encoding="utf-8") as f:
        json.dump({
            "idx": i,
            "prompt": prompt,
            "output": out,
            "input_salt_hex": "0x" + input_salt.hex(),
            "output_salt_hex": "0x" + output_salt.hex()
        }, f)
        f.write("\n")

    # Salted commitments: keccak_hex(salt || UTF-8 text)
    input_hash  = keccak_hex(input_salt  + prompt.encode("utf-8"))
    output_hash = keccak_hex(output_salt + out.encode("utf-8"))

    # crude feature list from prompt text: text before ":" in each comma-separated chunk
    keys = [p.split(":")[0].strip() for p in prompt.split(",") if ":" in p][:6] or \
           ["age","sex","admission","diagnosis","med","lab"]

    # --- stronger XAI: token-drop ΔNLL instead of length delta ---
    xai = token_drop_report(prompt, keys, out)
    xai_bytes = canonical_json({
        "prompt_sha": input_hash,
        "output_sha": output_hash,
        "method": "token_drop_nll",
        "report": xai
    })
    xai_hash = keccak_hex(xai_bytes)
    try:
        xai_cid = pin_json(json.loads(xai_bytes.decode("utf-8")))
    except Exception as err:
        # Best-effort: continue with a placeholder CID, but say so. A silent fake CID
        # ends up in the committed manifest and is indistinguishable from a real pin.
        print(f"⚠️ pin_json failed for inference {i} ({err!r}); using placeholder xaiCID")
        xai_cid = "bafkrei" + "1"*44

    events.append({
        "inputHash":  input_hash,
        "outputHash": output_hash,
        "xaiHash":    xai_hash,
        "xaiCID":     norm_ipfs(xai_cid)
    })

len(events), events[0]


In [None]:
# --- Token-drop ΔNLL bar chart---
import json
import matplotlib.pyplot as plt
from textwrap import shorten

def plot_token_drop_delta_nll(xai_obj,
                              save_path="token_drop_delta_nll.png",
                              sort_by_value=False,
                              max_label_len=28):
    """
    Draw token-drop ΔNLL attributions as a bar chart, show it, and save it to ``save_path``.

    xai_obj: dict or JSON string with structure:
      {"method":"token_drop_nll", "report":{"attributions":[{"delta_nll":..., "feature": ...}, ...]}}
      A dict without "report" is treated as the report itself.
    sort_by_value: if True, order bars by ΔNLL, largest first.
    max_label_len: maximum x-tick label width in characters.
    """
    # Accept dict or JSON string
    if isinstance(xai_obj, str):
        xai_obj = json.loads(xai_obj)

    report = xai_obj.get("report", xai_obj)
    attrs = report.get("attributions", [])

    vals = [float(a.get("delta_nll", 0.0)) for a in attrs]
    # Features cut from dict-like prompts keep their quotes (e.g. "'hadm_id'"); drop them.
    raw_labels = [str(a.get("feature", "")).strip().strip("'").strip('"') for a in attrs]

    # Nice labels for the features you showed
    def pretty_label(s: str) -> str:
        mapping = {
            "provide a short summary": "Instruction (summary)",
            "amountuom": "Amount UOM",
            "caregiver_id": "Caregiver ID",
            "endtime": "End Time",
            "hadm_id": "Admission ID (HADM)",
            "itemid": "Item ID",
        }
        return mapping.get(s, s.replace("_", " ").title())

    def short_label(s: str) -> str:
        # textwrap.shorten drops whole words, so a label whose first word is longer than
        # the width collapses to just "…". Hard-truncate that case instead.
        collapsed = " ".join(s.split())
        if len(collapsed) <= max_label_len:
            return collapsed
        out = shorten(s, width=max_label_len, placeholder="…")
        if out == "…":
            out = collapsed[:max(max_label_len - 1, 0)] + "…"
        return out

    labels = [pretty_label(s) for s in raw_labels]

    # Optional: sort by ΔNLL descending
    if sort_by_value:
        ranked = sorted(zip(labels, vals), key=lambda x: x[1], reverse=True)
        if ranked:
            labels, vals = map(list, zip(*ranked))

    fig, ax = plt.subplots(figsize=(6.5, 3.8))
    ax.bar(range(len(vals)), vals)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels([short_label(lbl) for lbl in labels],
                       rotation=25, ha="right")
    ax.set_ylabel("ΔNLL (token-drop)")
    ax.set_title("Token-drop ΔNLL attributions")
    ax.grid(axis="y", linestyle=":", linewidth=0.7)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"Saved figure → {save_path}")

# ---- Example: paste your JSON here (from your XAI bundle/report) ----
# Hard-coded sample of one token-drop XAI bundle (method, prompt/output hashes and
# per-feature ΔNLL attributions); it is not read from `events` in this notebook.
xai_json = {"method":"token_drop_nll","output_sha":"0x055ad709fab67008ccfd75d6954f042ac52609cf8150582523cb4c9c03d421c2","prompt_sha":"0xe82857164c571b46760ec3008cc4b3b64d1a3e1766304ca5437f4d323deebcaf","report":{"attributions":[{"delta_nll":0.8782769441604614,"feature":"provide a short summary"},{"delta_nll":11.175613462924957,"feature":"'amountuom'"},{"delta_nll":18.746244430541992,"feature":"'caregiver_id'"},{"delta_nll":3.121084749698639,"feature":"'endtime'"},{"delta_nll":14.985986232757568,"feature":"'hadm_id'"},{"delta_nll":7.941444456577301,"feature":"'itemid'"}],"base_output":"You are a clinical assistant. Given the following record from ingredientevents.csv, provide a short summary:\n{'amount': '158.16667138040066', 'amountuom': 'ml', 'caregiver_id': '12929', 'endtime': '2132-12-16 06:23:00', 'hadm_id': '20626031', 'itemid': '220490', 'linkorderid': '9595817', 'orderid': '9595817', 'originalamount': '0', 'originalrate': '250', 'rate': '10', 'rateuom': 'mL/hour', 'startt"}}

# Make & auto-save the figure (bars in report order, written to token_drop_delta_nll.png)
plot_token_drop_delta_nll(xai_json, save_path="token_drop_delta_nll.png", sort_by_value=False)


In [None]:
# Build Merkle root over keccak32(inputHash||outputHash||xaiHash), with sorted-pair parenting
def b32(x: str) -> bytes:
    assert x.startswith("0x") and len(x)==66
    return bytes.fromhex(x[2:])

def parent_sorted(a: bytes, b: bytes) -> bytes:
    """Parent node of two siblings: keccak256 of the pair concatenated in ascending byte order."""
    lo, hi = (a, b) if a < b else (b, a)
    return Web3.keccak(lo + hi)

# One Merkle leaf per event: keccak256(inputHash || outputHash || xaiHash), each hash
# first converted from "0x…" hex to 32 raw bytes by b32().
leaves = [Web3.keccak(b32(e["inputHash"]) + b32(e["outputHash"]) + b32(e["xaiHash"])) for e in events]

def merkle_root(hashes: list[bytes]) -> bytes:
    """Merkle root over ``hashes`` using sorted-pair parents (parent_sorted).

    An odd node at any level is paired with a copy of itself; a single leaf is
    its own root.

    Raises:
        ValueError: if ``hashes`` is empty (matches the reveal cell's merkle_root;
            previously an opaque IndexError).
    """
    if not hashes:
        raise ValueError("no leaves")
    level = list(hashes)
    while len(level) > 1:
        if len(level) % 2 == 1:
            level.append(level[-1])
        level = [parent_sorted(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]

root_bytes = merkle_root(leaves)
# Web3.keccak (hence root_bytes) returns HexBytes, whose .hex() already includes "0x"
# under web3 6.x, so "0x" + root_bytes.hex() gave "0x0x…". bytes(...) gives plain hex.
batchRoot = "0x" + bytes(root_bytes).hex()

manifest = {
  "modelId": "0x" + model_id_bytes.hex(),
  "count": len(events),
  "leaf_rule": "keccak32(inputHash||outputHash||xaiHash)",
  "pair_rule": "sorted-pair keccak",
  "events": events,
  "created_at": int(time.time())
}
manifest["saltScheme"] = "hmac_sha256(master, modelId||uint64(idx)||tag)[:16]"

batch_cid = pin_json(manifest)
print("batchRoot:", batchRoot)
print("batch CID:", batch_cid)


**Commit inference batch on-chain**

In [None]:
# Safer conversion helpers
def as_bytes32(x):
    """Coerce ``x`` to exactly 32 raw bytes.

    Accepts bytes/bytearray of length 32, or a hex string of 64 digits with an
    optional "0x"/"0X" prefix. Raises ValueError otherwise. These are explicit
    raises rather than ``assert``, so validation still happens under ``python -O``
    before the value is sent in a transaction.
    """
    if isinstance(x, (bytes, bytearray)):
        if len(x) != 32:
            raise ValueError(f"expected 32 bytes, got {len(x)}")
        return bytes(x)
    s = str(x).strip()
    if s.startswith(("0x", "0X")):
        s = s[2:]
    if len(s) != 64:
        raise ValueError(f"expected 32-byte hex (64 nibbles), got {len(s)} nibbles")
    return bytes.fromhex(s)  # ValueError on non-hex characters

# You computed these earlier:
# - model_id_bytes (bytes)
# - root_bytes (bytes)  ← direct from merkle_root()
# - batchRoot (string "0x...")  ← derived from root_bytes
# - batch_cid (string)

# Use the BYTES value directly (avoid hex parsing issues)
batch_root_bytes = as_bytes32(root_bytes)

# Commit (modelId, Merkle root, number of events, ipfs:// URI of the pinned manifest).
# The returned receipt is kept as `rcpt_batch`; the reveal cell looks it up via globals().
rcpt_batch = send_tx(
    phase2.functions.commitInferenceBatch,
    model_id_bytes,
    batch_root_bytes,
    len(manifest["events"]),
    norm_ipfs(batch_cid)
)

print("commitInferenceBatch tx:", rcpt_batch.transactionHash.hex())
print("DONE ✅")


**Reveal script**

In [None]:
# --- AUTO: use existing PHASE2_ABI (or existing `phase2`) to fetch latest batch CID+root ---
import json
from web3 import Web3

try:
    RPC_URL
except NameError:
    # No RPC_URL variable yet: read it from the environment. A bare placeholder such as
    # <YOUR_RPC_URL> is not valid Python and would stop the whole cell from parsing.
    import os
    RPC_URL = os.getenv("RPC_URL", "")
    if not RPC_URL:
        raise RuntimeError("Set RPC_URL (notebook variable or RPC_URL env var) before running this cell") from None

try:
    PHASE2_ADDR
except NameError:
    raise RuntimeError("Set PHASE2_ADDR before running this cell") from None

w3 = Web3(Web3.HTTPProvider(RPC_URL))
if not w3.is_connected():
    raise RuntimeError("RPC not reachable")

# Reuse contract instance if present; otherwise build it with PHASE2_ABI you loaded earlier.
# NOTE: a reused `phase2` stays bound to the provider it was created with, not this new w3.
try:
    phase2  # already defined elsewhere?
    # sanity: if address differs, rebuild
    if phase2.address.lower() != Web3.to_checksum_address(PHASE2_ADDR).lower():
        phase2 = w3.eth.contract(address=Web3.to_checksum_address(PHASE2_ADDR), abi=PHASE2_ABI)
except NameError:
    phase2 = w3.eth.contract(address=Web3.to_checksum_address(PHASE2_ADDR), abi=PHASE2_ABI)

def _normalize_cid(value: str) -> str:
    """Return the bare CID: strip whitespace, quotes, then an ``ipfs://`` and an ``ipfs/`` prefix (None → "").

    Same rules as the reveal cell's normalize_cid, so both resolve a CID identically.
    """
    s = (value or "").strip().strip('"').strip("'")
    return s.removeprefix("ipfs://").removeprefix("ipfs/")

def get_latest_batch_for_model(_phase2, model_id_b: bytes, start_block=0):
    """Return ``(cid, root_hex)`` of the most recent InferenceBatchCommitted log for ``model_id_b``.

    Scans from ``start_block`` to the chain head, filtered by modelId, and takes the
    last log. Raises RuntimeError when no such event exists.
    """
    filter_args = {"modelId": model_id_b}
    found = _phase2.events.InferenceBatchCommitted().get_logs(
        fromBlock=start_block, toBlock="latest", argument_filters=filter_args
    )
    if len(found) == 0:
        raise RuntimeError("No InferenceBatchCommitted events found for this modelId")
    latest_args = found[-1]["args"]
    cid = _normalize_cid(latest_args["batchCID"])
    root_hex = "0x" + latest_args["batchRoot"].hex()
    return cid, root_hex

# `model_id_bytes` must already be set earlier in your flow
# NOTE(review): hard-coded block number of one specific deployment; the log scan starts
# here, so a value later than your contract's events would miss them — confirm.
DEPLOY_BLOCK = 8984586  # optionally set to your Phase-2 deploy block
# Setting BATCH_MANIFEST_CID here makes the reveal cell treat it as a manual CID override.
BATCH_MANIFEST_CID, BATCH_ROOT_ONCHAIN = get_latest_batch_for_model(phase2, model_id_bytes, start_block=DEPLOY_BLOCK)
print("Auto-selected BATCH_MANIFEST_CID:", BATCH_MANIFEST_CID)
print("Batch root (on-chain):", BATCH_ROOT_ONCHAIN)

# Gateways (keep as-is / or your existing list)
# NOTE(review): the reveal cell's fetch_manifest_json uses its own default gateway list
# unless this list is passed to it explicitly.
IPFS_GATEWAYS = [
    "https://gateway.pinata.cloud/ipfs/",
    "https://ipfs.io/ipfs/",
]





**Load prompt, output, and salts from inference_log.jsonl (added because manually pasting the salt hex values from the log did not work reliably)**

In [None]:
# Fill prompt_text / output_text / salts from your saved JSONL
import json
from pathlib import Path

JSONL_PATHS = [
    "/mnt/data/inference_log.jsonl",  # attached path
    "inference_log.jsonl",            # local working dir fallback
]
# NOTE: the first existing path wins, so an attached /mnt/data copy takes precedence
# over the log the inference cell appends to in the working directory.
JSONL_PATH = next((p for p in JSONL_PATHS if Path(p).exists()), None)
assert JSONL_PATH, "inference_log.jsonl not found (checked /mnt/data and CWD)."

# Choose which row to use: "latest" or an integer index
TARGET = "latest"  # or e.g., 0

rows = []
skipped_lines = 0
with open(JSONL_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError:
            # e.g. a partially written line: skip it, but count it instead of hiding it
            skipped_lines += 1
if skipped_lines:
    print(f"⚠️ Skipped {skipped_lines} malformed line(s) in {JSONL_PATH}")
assert rows, f"No rows in {JSONL_PATH}"

row = rows[-1] if TARGET == "latest" else rows[int(TARGET)]

prompt_text = row["prompt"]
output_text = row["output"]

# Handle either hex strings or raw bytes saved as repr
def _norm_hex(v):
    """Return ``v`` as hex text with exactly one "0x" prefix.

    Accepts bytes, hex text with any number of "0x" prefixes and optional quotes,
    or a bytes repr such as "b'\\x01\\x02'" (what str(bytes) produces). The repr
    is decoded with ast.literal_eval; before, it came out as an invalid "0xb'…'".
    """
    import ast
    if isinstance(v, (bytes, bytearray)):
        return "0x" + bytes(v).hex()
    raw = str(v).strip()
    # Check for a bytes repr before stripping quotes, which would break the repr.
    if len(raw) >= 3 and raw[:2] in ("b'", 'b"') and raw[-1] == raw[1]:
        try:
            val = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            val = None
        if isinstance(val, bytes):
            return "0x" + val.hex()
    s = raw.strip('"').strip("'")
    while s.lower().startswith("0x"):
        s = s[2:]
    return "0x" + s

# Prefer the *_salt_hex keys written by the inference cell; fall back to input_salt/output_salt.
# A missing salt normalizes to "0x", which the reveal cell treats as a placeholder (non-reveal mode).
input_salt_hex  = _norm_hex(row.get("input_salt_hex", row.get("input_salt", "")))
output_salt_hex = _norm_hex(row.get("output_salt_hex", row.get("output_salt", "")))

print("Loaded from:", JSONL_PATH, "| idx:", row.get("idx"))
print("input_salt_hex :", input_salt_hex)
print("output_salt_hex:", output_salt_hex)
print("prompt_text[:90]:", prompt_text[:90].replace("\n"," "))
print("output_text[:90]:", output_text[:90].replace("\n"," "))

# (Optional) quick match against an already-loaded manifest `events`
try:
    from web3 import Web3
    # NOTE: (re)defines the notebook-global keccak_hex that the inference cell also uses.
    # `.hex()[2:]` assumes HexBytes.hex() returns a "0x"-prefixed string (hexbytes < 1.0).
    def keccak_hex(b: bytes) -> str: return "0x" + Web3.keccak(b).hex()[2:]
    # Recompute the salted commitments exactly as the inference cell built them.
    ih = keccak_hex(bytes.fromhex(input_salt_hex[2:])  + prompt_text.encode("utf-8"))
    oh = keccak_hex(bytes.fromhex(output_salt_hex[2:]) + output_text.encode("utf-8"))
    if "events" in globals():
        match_i = next((i for i,e in enumerate(events)
                        if e["inputHash"].lower()==ih.lower()
                        and e["outputHash"].lower()==oh.lower()), None)
        print("Matched index in manifest:", match_i)
except Exception as e:
    print("Manifest match check skipped:", e)


In [None]:
# ----- Robust reveal OR non-reveal membership check -----
import json, requests, base64, ast, string
from web3 import Web3

HEXCHARS = set(string.hexdigits)

def b32(x: str|bytes) -> bytes:
    if isinstance(x, (bytes, bytearray)):
        assert len(x) == 32, f"expected 32 bytes, got {len(x)}"
        return bytes(x)
    s = str(x).strip().strip('"').strip("'")
    while s.lower().startswith("0x"):  # strip ANY number of 0x prefixes
        s = s[2:]
    if len(s) != 64:
        raise ValueError(f"expected 64 hex chars (32 bytes), got {len(s)}")
    return bytes.fromhex(s)

def keccak_hex(b: bytes) -> str:
    """Keccak-256 of ``b`` as a "0x"-prefixed, 64-hex-digit string.

    bytes(...) strips the HexBytes wrapper, so the result does not depend on whether
    the installed hexbytes' ``.hex()`` adds "0x" (it does before 1.0, not from 1.0 on).
    The previous ``.hex()[2:]`` would silently drop a real byte under hexbytes >= 1.0.
    """
    return "0x" + bytes(Web3.keccak(b)).hex()

def parent_sorted(a: bytes, b: bytes) -> bytes:
    """Order-independent parent hash: keccak256 of the two siblings in ascending byte order."""
    first, second = sorted((a, b))
    return Web3.keccak(first + second)

def merkle_root(hashes: list[bytes]) -> bytes:
    """Root of a sorted-pair Merkle tree over ``hashes``.

    An odd node at a level is paired with a copy of itself. Raises ValueError("no leaves")
    for empty input.
    """
    if len(hashes) == 0:
        raise ValueError("no leaves")
    nodes = list(hashes)
    while len(nodes) != 1:
        if len(nodes) & 1:
            nodes = nodes + [nodes[-1]]
        nodes = [parent_sorted(left, right) for left, right in zip(nodes[::2], nodes[1::2])]
    return nodes[0]

def merkle_proof(hashes: list[bytes], idx: int) -> list[bytes]:
    """Bottom-up sibling path for leaf ``idx`` in the tree built by merkle_root.

    Uses the same odd-node duplication, so a duplicated last node's sibling is itself.
    Raises IndexError("idx out of range") if ``idx`` is outside ``range(len(hashes))``.
    """
    if idx < 0 or idx >= len(hashes):
        raise IndexError("idx out of range")
    path = []
    pos = idx
    nodes = list(hashes)
    while len(nodes) > 1:
        if len(nodes) % 2:
            nodes.append(nodes[-1])
        path.append(nodes[pos ^ 1])  # sibling index: flip the lowest bit
        pos //= 2                   # parent's index on the next level
        nodes = [parent_sorted(nodes[j], nodes[j + 1]) for j in range(0, len(nodes), 2)]
    return path

def normalize_cid(value: str) -> str:
    """Bare CID from a CID or URI: trim whitespace and quotes, drop "ipfs://" then "ipfs/" (None → "")."""
    cleaned = (value or "").strip().strip('"').strip("'")
    return cleaned.removeprefix("ipfs://").removeprefix("ipfs/")

def norm_ipfs(value: str) -> str:
    """Return a clean ipfs://… URI no matter what you pass in (CID or ipfs://CID); "" if nothing is left."""
    bare = normalize_cid(value)
    if not bare:
        return ""
    return "ipfs://" + bare

def fetch_manifest_json(cid: str, gateways=None) -> dict:
    """Download ``cid`` and JSON-decode it from the first IPFS gateway that succeeds.

    Gateways are tried in order (default: Pinata, ipfs.io, dweb.link) with a 60 s
    timeout each. Raises RuntimeError carrying the last error if all of them fail.
    """
    candidates = gateways or [
        "https://gateway.pinata.cloud/ipfs/",
        "https://ipfs.io/ipfs/",
        "https://dweb.link/ipfs/",
    ]
    last_err = None
    for base in candidates:
        url = base.rstrip("/") + "/" + cid
        try:
            resp = requests.get(url, timeout=60, headers={"Accept": "application/json"})
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Could not fetch manifest {cid}: {last_err}")

def get_latest_batch_for_model(phase2_contract, model_id_b, rcpt_batch_obj=None, start_block=0):
    """Return ``(cid, root_hex)`` of the latest InferenceBatchCommitted batch for ``model_id_b``.

    Fast path: decode the event from ``rcpt_batch_obj`` (e.g. the commit receipt), but
    only use an event whose modelId equals ``model_id_b``. Previously the first event
    was returned unchecked, so a stale receipt from another model supplied the wrong
    batch. Otherwise, scan logs from ``start_block`` and take the last one.

    Raises:
        RuntimeError: if no InferenceBatchCommitted event exists for ``model_id_b``.
    """
    if rcpt_batch_obj is not None:
        try:
            ev = phase2_contract.events.InferenceBatchCommitted().process_receipt(rcpt_batch_obj)
            for item in ev:
                a = item["args"]
                if bytes(a["modelId"]) == bytes(model_id_b):
                    return normalize_cid(a["batchCID"]), "0x"+a["batchRoot"].hex()
        except Exception as e:
            print("Receipt parse warn:", e)
    logs = phase2_contract.events.InferenceBatchCommitted().get_logs(
        fromBlock=start_block, toBlock="latest", argument_filters={"modelId": model_id_b}
    )
    if not logs: raise RuntimeError("No InferenceBatchCommitted events for this modelId")
    a = logs[-1]["args"]
    return normalize_cid(a["batchCID"]), "0x"+a["batchRoot"].hex()

# If phase2 isn't in scope (fallback to minimal wiring from earlier blocks)
# Requires w3, PHASE2_ADDR and PHASE2_ABI from earlier cells.
try:
    phase2  # already in scope?
except NameError:
    phase2 = w3.eth.contract(address=Web3.to_checksum_address(PHASE2_ADDR), abi=PHASE2_ABI)

# === Resolve CID & on-chain root (handles manual override or auto-pick) ===
cid_override = normalize_cid(globals().get("BATCH_MANIFEST_CID", ""))
if cid_override:
    # We still fetch the on-chain root for cross-check.
    # NOTE: this is the root of the LATEST batch for the model; if the override CID is an
    # older batch, the root comparison below fails.
    _, batch_root_onchain = get_latest_batch_for_model(phase2, model_id_bytes, globals().get("rcpt_batch"), 0)
    cid = cid_override
else:
    DEPLOY_BLOCK = 0  # set to Phase-2 deploy block if you want faster scans
    cid, batch_root_onchain = get_latest_batch_for_model(
        phase2, model_id_bytes, globals().get("rcpt_batch"), DEPLOY_BLOCK
    )
print("CID:", cid)
print("On-chain batchRoot:", batch_root_onchain)

# Fetch manifest → compute root → compare to on-chain
manifest = fetch_manifest_json(cid)
events = manifest["events"]
assert len(events) == manifest["count"], "manifest count mismatch"

leaves = [Web3.keccak(b32(e["inputHash"]) + b32(e["outputHash"]) + b32(e["xaiHash"])) for e in events]
root_bytes = merkle_root(leaves)
onchain_root_bytes = b32(batch_root_onchain)
# root_bytes is HexBytes (from Web3.keccak), whose .hex() already includes "0x" under
# web3 6.x; going through bytes() avoids "0x0x…" in messages and batch_root_hex.
root_hex = "0x" + bytes(root_bytes).hex()

assert root_bytes == onchain_root_bytes, (
    f"Computed root != on-chain root\n"
    f"computed: {root_hex}\n"
    f"onchain : 0x{onchain_root_bytes.hex()}"
)
print("Batch root matches on-chain:", root_hex)

# Decide mode based on salts
def looks_like_placeholder(s):
    return isinstance(s, str) and s.strip() in ("", "0x...", "0x", "...")

REVEAL_MODE = not (looks_like_placeholder(globals().get("input_salt_hex","")) or
                   looks_like_placeholder(globals().get("output_salt_hex","")))
print("Mode:", "FULL REVEAL" if REVEAL_MODE else "NON-REVEAL (membership only)")

if REVEAL_MODE:
    def salt_from_any(x) -> bytes:
        # Accept raw bytes, a bytes repr "b'...'", even-length hex (optional 0x), or base64.
        if isinstance(x, (bytes, bytearray)): return bytes(x)
        s = str(x).strip().strip('"').strip("'")
        if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
            try: return ast.literal_eval(s)
            except Exception: pass
        h = s[2:] if s.lower().startswith("0x") else s
        if len(h) >= 2 and len(h) % 2 == 0 and all(c in HEXCHARS for c in h):
            try: return bytes.fromhex(h)
            except Exception: pass
        try: return base64.b64decode(s, validate=True)
        except Exception: pass
        raise ValueError(f"Unrecognized salt format: {x!r}")

    inp_salt = salt_from_any(input_salt_hex)
    out_salt = salt_from_any(output_salt_hex)
    print(f"Salt lengths → input: {len(inp_salt)} bytes, output: {len(out_salt)} bytes")

    # Recompute the salted commitments and locate the matching manifest event
    input_hash  = keccak_hex(inp_salt + prompt_text.encode("utf-8"))
    output_hash = keccak_hex(out_salt + output_text.encode("utf-8"))
    print("Recomputed inputHash :", input_hash)
    print("Recomputed outputHash:", output_hash)

    match_i = next((i for i,e in enumerate(events)
                    if e["inputHash"].lower()==input_hash.lower()
                    and e["outputHash"].lower()==output_hash.lower()), None)
    assert match_i is not None, "No matching event found in manifest (check salts/prompt/output)"
    xai_hash = events[match_i]["xaiHash"]
    print("Matched index:", match_i, "xaiHash:", xai_hash)
else:
    TARGET_INDEX = 0  # change if you want a different row
    match_i = TARGET_INDEX
    assert 0 <= match_i < len(events), "TARGET_INDEX out of range"
    xai_hash = events[match_i]["xaiHash"]
    input_hash = events[match_i]["inputHash"]
    output_hash = events[match_i]["outputHash"]
    print(f"Picked event #{match_i} from manifest (non-reveal).")

# Build proof & verify
proof_bytes = merkle_proof(leaves, match_i)
print("Proof length:", len(proof_bytes))

batch_root_hex = root_hex  # for any later JSON/UI use (single "0x" prefix)

ih = b32(input_hash); oh = b32(output_hash); xh = b32(xai_hash)
ok = phase2.functions.verifyBatchMembership(ih, oh, xh, proof_bytes, root_bytes).call()
print("On-chain verification result:", ok)
assert ok, "verifyBatchMembership returned false (check inputs)"
print("✅ Membership proof verified", "(full reveal)" if REVEAL_MODE else "(non-reveal)")


**Merkle Batch Membership on-chain Verification Helpers**

In [None]:
def verify_by_index(idx: int = 0):
    """Check on-chain that manifest event ``idx`` belongs to the batch; print and return the result.

    Uses the notebook globals events, leaves, root_bytes and phase2.
    """
    assert 0 <= idx < len(events), "idx out of range"
    entry = events[idx]
    leaf_parts = [b32(entry[field]) for field in ("inputHash", "outputHash", "xaiHash")]
    proof = merkle_proof(leaves, idx)
    ok = phase2.functions.verifyBatchMembership(*leaf_parts, proof, root_bytes).call()
    print(f"Index {idx} →", ok)
    return ok

# Example:
verify_by_index(0)


In [None]:
def verify_all():
    """Verify every manifest event on-chain and print a "verified/total" tally.

    Each index whose verifyBatchMembership call returns false is printed. Uses the
    notebook globals events, leaves, root_bytes and phase2.
    """
    verified = 0
    for i, entry in enumerate(events):
        ih, oh, xh = (b32(entry[field]) for field in ("inputHash", "outputHash", "xaiHash"))
        proof = merkle_proof(leaves, i)
        ok = phase2.functions.verifyBatchMembership(ih, oh, xh, proof, root_bytes).call()
        if ok:
            verified += 1
        else:
            print("Mismatch at index", i)
    print(f"{verified}/{len(events)} verified")

# Example:
verify_all()


# Phase3: Generated Output Verification

In [None]:
# === PHASE 3 — attach OutputAuthentication===
import json, os
from web3 import Web3

assert 'w3' in globals() and 'acct' in globals() and 'send_tx' in globals(), "Run Phase-2 cells first."

# 1) Put your Remix-deployed address here (or via env PHASE3_ADDR)
# The "<add_address>" default is a placeholder: to_checksum_address below rejects it.
PHASE3_ADDR = os.getenv("PHASE3_ADDR", "<add_address>")

# 2) Paste the ABI JSON from Remix between the triple quotes:
# "<add_abi>" is not JSON, so json.loads raises until the real ABI is pasted in.
PHASE3_ABI = json.loads(r"""<add_abi>""")

phase3 = w3.eth.contract(address=Web3.to_checksum_address(PHASE3_ADDR), abi=PHASE3_ABI)
print("Phase-3 attached at:", phase3.address)

# Optional sanity: confirm you're the Phase-2 model owner (must match msg.sender for storeContentHash)
try:
    owner = phase2.functions.getModel(model_id_bytes).call()[0]  # owner is first field in your Phase-2 ABI
    print("Phase-2 model owner:", owner)
    print("Your tx sender    :", acct.address)
except Exception as e:
    print("Owner check skipped:", e)


**Generate with Phase-2 model ➜ hash raw output ➜ store in Phase-3**

*Ensuring that the correct phase2 ABI is loaded*

In [None]:
# --- Rebind phase2 with the FULL ABI
# Replace the placeholders: to_checksum_address("add_address") and json.loads("<add_abi>")
# both raise as written. This overwrites the PHASE2_ADDR/phase2 used by earlier cells.
PHASE2_ADDR = Web3.to_checksum_address("add_address")

PHASE2_ABI_FULL = json.loads(r"""<add_abi>""")

phase2 = w3.eth.contract(address=PHASE2_ADDR, abi=PHASE2_ABI_FULL)

# sanity: the batch-lookup helpers below need the InferenceBatchCommitted event in the ABI
ev_names = {e['name'] for e in phase2.abi if e.get('type') == 'event'}
print("Has InferenceBatchCommitted?", "InferenceBatchCommitted" in ev_names)
print("Phase-2 address:", phase2.address)

# --- robust latest-batch helper (uses event if present; falls back to raw log scan) ---
from eth_abi import decode as abi_decode

def latest_phase2_batch(model_id_b, from_block=0, to_block="latest"):
    """Return ``(batch_cid, batch_root_hex)`` of the newest InferenceBatchCommitted log for ``model_id_b``.

    A) decode through the contract ABI event; B) if that fails or finds nothing,
    scan raw logs by topic0 and ABI-decode the ``(uint256 count, string batchCID)``
    data payload. Returns (None, None) when neither path finds a log.
    """
    # A) try via ABI event
    try:
        ev = phase2.events.InferenceBatchCommitted()
        logs = ev.get_logs(fromBlock=from_block, toBlock=to_block, argument_filters={"modelId": model_id_b})
        if logs:
            a = logs[-1]["args"]
            cid = a.get("batchCID", None)
            root_b32 = a.get("batchRoot", None)
            root_hex = "0x" + bytes(root_b32).hex() if isinstance(root_b32, (bytes, bytearray)) else (root_b32 if isinstance(root_b32, str) else None)
            return cid, root_hex
    except Exception as e:
        # Best-effort: report why, then fall through to the raw-log scan.
        print("Event path failed, falling back to raw log scan:", e)

    # B) raw log scan by topic (works even if ABI is wrong)
    topic0 = Web3.keccak(text="InferenceBatchCommitted(bytes32,bytes32,uint256,string)")
    flt = {
        "fromBlock": from_block,
        "toBlock": to_block,
        "address": phase2.address,
        "topics": [topic0, model_id_b]  # topics[1] is indexed modelId
    }
    logs = w3.eth.get_logs(flt)
    if logs:
        lg = logs[-1]
        # topics[2] is indexed batchRoot; bytes(...) avoids relying on HexBytes.hex()'s prefix
        batch_root_hex = "0x" + bytes(lg["topics"][2]).hex()
        # data encodes (uint256 count, string batchCID). It may arrive as HexBytes
        # (web3 v6+) or as a "0x…" str (older); bytes.fromhex(data[2:]) alone only
        # handles the str form and raises TypeError on HexBytes.
        try:
            raw = lg["data"]
            if isinstance(raw, (bytes, bytearray)):
                data_bytes = bytes(raw)
            else:
                data_bytes = bytes.fromhex(raw[2:] if raw.lower().startswith("0x") else raw)
            _count, batch_cid = abi_decode(["uint256","string"], data_bytes)
        except Exception:
            batch_cid = None
        return batch_cid, batch_root_hex

    return None, None


In [None]:
# Debug: compare the model owner as seen by Phase-2 and by Phase-3.
# NOTE(review): element [2] of debugOwnerAndSender is printed as Phase-3's phase2 address —
# presumably its third return value; confirm against the Phase-3 ABI.
print("Phase-3 phase2():", phase3.functions.debugOwnerAndSender(model_id_bytes).call()[2])
print("owner (phase2.getModel):", phase2.functions.getModel(model_id_bytes).call()[0])
print("owner (phase3.ownerOf): ", phase3.functions.ownerOf(model_id_bytes).call())


In [None]:
# === PHASE 3 — generate -> hash -> store (raw content authentication) ===
import os, time, json
import numpy as np
from web3 import Web3
from web3._utils.events import get_event_data
from eth_abi import decode as abi_decode  # available with web3 install

# -- tiny helpers
def _rng():
    # A fresh generator with a fixed seed on every call, so each call makes the same draw.
    return np.random.default_rng(12345)

def _pick_prompt():
    """Pick a prompt from the notebook-global ``pairs`` reproducibly, or fall back to a demo prompt.

    ``pairs`` entries may be dicts with a "prompt" key or non-empty lists/tuples
    whose first element is the prompt.
    """
    pool = globals().get("pairs")
    if isinstance(pool, (list, tuple)) and pool:
        choice = pool[_rng().integers(0, len(pool))]
        # support dict or tuple forms
        if isinstance(choice, dict):
            if "prompt" in choice:
                return choice["prompt"]
        elif isinstance(choice, (list, tuple)) and choice:
            return choice[0]
    return "Summarize: {patient_id: 10293, admission_type: EMERGENCY, diagnosis: Pneumonia}"

_tok = tokenizer if "tokenizer" in globals() else tok  # whichever name you used in Phase 2

def generate(text, max_new=160, *, include_prompt=False):
    """Greedy-decode up to ``max_new`` new tokens for ``text`` with the global ``model``.

    Returns only the newly generated text by default. ``model.generate`` returns the
    prompt followed by the continuation, and decoding all of it meant the content
    hash authenticated prompt+answer rather than the generated output a user would
    later present for verification. Pass ``include_prompt=True`` for the old behavior.
    """
    inputs = _tok(text, return_tensors="pt").to(model.device)
    model.eval()  # no dropout → reproducible greedy output
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new, do_sample=False)
    seq = out[0] if include_prompt else out[0][inputs["input_ids"].shape[1]:]
    return _tok.decode(seq, skip_special_tokens=True)

prompt_text = _pick_prompt()
output_text = generate(prompt_text)
print("OUTPUT preview:", output_text[:220], "...")

# Content hash (bytes32): keccak256 of the output text's UTF-8 bytes
content_hash = Web3.keccak(text=output_text)
# Web3.keccak returns HexBytes whose .hex() already includes "0x" under web3 6.x,
# so convert to plain bytes first to avoid printing "0x0x…".
print("contentHash:", "0x" + bytes(content_hash).hex())

# ---- Find latest Phase-2 batch (CID + root) robustly ----
def _event_exists(contract, name: str) -> bool:
    """True if ``contract.abi`` declares an event named ``name``; False otherwise or if the ABI is malformed."""
    try:
        # quick check against ABI JSON
        return any(
            entry.get("type") == "event" and entry.get("name") == name
            for entry in getattr(contract, "abi", [])
        )
    except Exception:
        return False

def _latest_batch_from_event(phase2_contract, model_id_b, from_block=0, to_block="latest"):
    """
    Latest InferenceBatchCommitted batch for ``model_id_b``, decoded through the contract ABI.

    Requires the ABI to include:
      event InferenceBatchCommitted(bytes32 indexed modelId, bytes32 indexed batchRoot, uint256 count, string batchCID);
    Returns (cid_str, root_hex or None), or (None, None) if no log is found.
    Raises RuntimeError if the ABI lacks the event.
    """
    if not _event_exists(phase2_contract, "InferenceBatchCommitted"):
        raise RuntimeError("Phase-2 ABI lacks InferenceBatchCommitted")

    found = phase2_contract.events.InferenceBatchCommitted().get_logs(
        fromBlock=from_block, toBlock=to_block, argument_filters={"modelId": model_id_b}
    )
    if not found:
        return None, None

    # web3 decodes indexed and non-indexed params into args
    args = found[-1]["args"]
    cid = args.get("batchCID", None)
    root = args.get("batchRoot", None)
    if isinstance(root, (bytes, bytearray)):
        return cid, "0x" + root.hex()
    if isinstance(root, str):
        return cid, root
    return cid, None

def _latest_batch_from_txhash(phase2_contract, commit_tx_hash: str):
    """
    Fallback: parse a known commit tx receipt without having the event in the ABI.
    You must provide INFER_COMMIT_TX (env or variable) that corresponds to a commitInferenceBatch tx.

    Returns (batch_cid, batch_root_hex) from the first InferenceBatchCommitted log that
    ``phase2_contract`` emitted in that receipt; (None, root_hex) if the data payload
    cannot be decoded; (None, None) if no hash was given or no such log exists.
    """
    if not commit_tx_hash:
        return None, None
    r = w3.eth.get_transaction_receipt(commit_tx_hash)
    # topic0 = keccak("InferenceBatchCommitted(bytes32,bytes32,uint256,string)")
    topic0 = Web3.keccak(text="InferenceBatchCommitted(bytes32,bytes32,uint256,string)")
    addr = phase2_contract.address.lower()

    for lg in r["logs"]:
        if lg["address"].lower() != addr:
            continue
        if len(lg["topics"]) < 3 or lg["topics"][0] != topic0:
            continue
        # topics[1] = modelId (indexed), topics[2] = batchRoot (indexed);
        # bytes(...) avoids relying on whether HexBytes.hex() includes "0x"
        batch_root_hex = "0x" + bytes(lg["topics"][2]).hex()
        # data encodes (uint256 count, string batchCID). It may be HexBytes (web3 v6+)
        # or a "0x…" str (older); bytes.fromhex(data[2:]) alone raises TypeError on HexBytes.
        try:
            raw = lg["data"]
            if isinstance(raw, (bytes, bytearray)):
                data_bytes = bytes(raw)
            else:
                data_bytes = bytes.fromhex(raw[2:] if raw.lower().startswith("0x") else raw)
            _count, batch_cid = abi_decode(["uint256", "string"], data_bytes)
            return batch_cid, batch_root_hex
        except Exception:
            # If decode fails, still return the root; CID unknown
            return None, batch_root_hex
    return None, None

def latest_phase2_batch(model_id_b, from_block=0, to_block="latest"):
    """
    Return ``(batch_cid, batch_root_hex)`` for the latest Phase-2 batch of a model.

    Sources are tried in order:
      A. ``InferenceBatchCommitted`` logs decoded through the Phase-2 ABI;
      B. the receipt of INFER_COMMIT_TX (env var or notebook global), parsed raw.
    Path B is used both when path A raises and when it finds no matching log.

    Args:
        model_id_b: bytes32 model id used as the ``modelId`` log filter.
        from_block: first block of the path-A log query (default 0, whole chain).
        to_block: last block of the path-A log query (default "latest").
            Narrow the range on RPC providers that cap ``eth_getLogs`` spans.

    Returns:
        ``(None, None)`` when neither source yields a CID or root.
    """
    # Path A: try event via ABI first
    try:
        cid, root_hex = _latest_batch_from_event(phase2, model_id_b, from_block=from_block, to_block=to_block)
        if cid or root_hex:
            return cid, root_hex
    except Exception as e:
        print("Batch via ABI event not available:", e)

    # Path B: fallback via known tx hash (optional)
    commit_tx = os.getenv("INFER_COMMIT_TX") or globals().get("INFER_COMMIT_TX")
    if commit_tx:
        cid, root_hex = _latest_batch_from_txhash(phase2, commit_tx)
        if cid or root_hex:
            return cid, root_hex

    # Path C: no information from either source
    return None, None




In [None]:
# Fetch the latest Phase-2 batch for this model and bind it to the Phase-3 content hash.
cid, batch_root_hex = latest_phase2_batch(model_id_bytes)

# batch_root_hex can be "0x"-prefixed hex, bare hex, or None (no batch found):
# strip any 0x prefix(es), left-pad to 32 bytes, and fall back to the zero root.
_root_digits = (batch_root_hex or "").strip()
while _root_digits[:2].lower() == "0x":
    _root_digits = _root_digits[2:]
batch_root_b32 = bytes.fromhex(_root_digits.rjust(64, "0")) if _root_digits else b"\x00" * 32
if len(batch_root_b32) != 32:
    raise ValueError(f"batch root is not 32 bytes: {batch_root_hex!r}")

xai_cid = norm_ipfs(cid)  # handles None, raw CID, or ipfs://CID
rcpt = send_tx(phase3.functions.storeContentHash, model_id_bytes, content_hash, batch_root_b32, xai_cid)
print("storeContentHash tx:", rcpt.transactionHash.hex())

**Verify later: check a stored output against the chain**

In [None]:
# === PHASE 3 — verify a content string against the chain (via event logs) ===
from web3 import Web3
from eth_abi import decode as abi_decode

def verify_output_onchain(raw_text: str, model_id_hint: bytes = None, from_block=0, to_block="latest"):
    """
    Check whether ``Web3.keccak(text=raw_text)`` was recorded by a Phase-3 ``ContentStored`` event.

    Tries the ABI-decoded event first. If that path raises (for example, the local
    ABI lacks the event), it falls back to a raw ``eth_getLogs`` topic scan.
    Prints the latest matching record.

    Args:
        raw_text: the exact output text whose hash was stored.
        model_id_hint: optional bytes32 modelId used as an extra log filter.
        from_block: first block of the log query.
        to_block: last block of the log query.

    Returns:
        True if a matching log exists, else False.
    """

    def _as_bytes(value) -> bytes:
        # Log fields are HexBytes/bytes on newer web3 and 0x-hex strings on older
        # releases; normalize instead of slicing .hex() output or a raw data field.
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        text = str(value)
        return bytes.fromhex(text[2:] if text[:2].lower() == "0x" else text)

    def _report(submitter, model_id_b32, batch_root_b32, xai_cid, ts):
        # Shared printout for both lookup paths.
        print("Exists   :", True)
        print("submitter:", submitter)
        print("modelId  :", "0x" + model_id_b32.hex())
        print("batchRoot:", "0x" + batch_root_b32.hex())
        print("xaiCID   :", xai_cid)
        print("time     :", ts)

    h = Web3.keccak(text=raw_text)  # bytes32
    print("contentHash:", h.hex())

    # Try ABI-decoded event first (simplest)
    try:
        ev = phase3.events.ContentStored()
        arg_filters = {"contentHash": h}
        if model_id_hint is not None:
            arg_filters["modelId"] = model_id_hint
        logs = ev.get_logs(fromBlock=from_block, toBlock=to_block, argument_filters=arg_filters)
        if not logs:
            print("Exists: False")
            return False

        a = logs[-1]["args"]
        _report(
            a.get("submitter"),
            _as_bytes(a.get("modelId")),
            _as_bytes(a.get("batchRoot")),
            a.get("xaiCID"),
            a.get("timestamp"),
        )
        return True

    except Exception as e:
        print("ABI event path failed, falling back to raw log scan:", e)

    # Fallback: raw topic scan (works even if your local ABI is missing the event)
    # event ContentStored(bytes32 indexed modelId, bytes32 indexed contentHash,
    #                     bytes32 indexed batchRoot, address submitter, string xaiCID, uint64 timestamp)
    topic0 = Web3.keccak(text="ContentStored(bytes32,bytes32,bytes32,address,string,uint64)")
    # Topic filters are sent as 0x-hex strings so they serialize identically across
    # web3 versions; None is a wildcard position.
    topics = ["0x" + _as_bytes(topic0).hex(), None, "0x" + _as_bytes(h).hex(), None]  # by contentHash
    if model_id_hint is not None:
        topics[1] = "0x" + _as_bytes(model_id_hint).hex()  # stricter filter when the model is known

    logs = w3.eth.get_logs({
        "fromBlock": from_block,
        "toBlock": to_block,
        "address": phase3.address,
        "topics": topics,
    })
    if not logs:
        print("Exists: False")
        return False

    lg = logs[-1]
    # Indexed topics order: [topic0, modelId, contentHash, batchRoot]
    model_id_b32 = _as_bytes(lg["topics"][1])
    batch_root_b32 = _as_bytes(lg["topics"][3])

    # Data encodes the non-indexed params: (address submitter, string xaiCID, uint64 timestamp)
    submitter, xai_cid, ts = abi_decode(["address", "string", "uint64"], _as_bytes(lg["data"]))

    _report(submitter, model_id_b32, batch_root_b32, xai_cid, ts)
    return True

# Usage: verify the generated output, restricted to this notebook's model id.
verify_output_onchain(output_text, model_id_hint=model_id_bytes, from_block=0)


**Build & sign an EIP-712 receipt**

In [None]:
import os
# Signing key for the EIP-712 receipt: the **model owner** key or an **allowed publisher** key.
# Prefer exporting EOA_PRIV_KEY (or PUB_PRIV_KEY) before starting the notebook; a key that is
# already present in the environment is never overwritten by this cell. For a quick local run
# you may paste the key below instead -- never commit a notebook that contains a real key.
_MANUAL_PRIV_KEY = "<insert_pk>"

if os.environ.get("EOA_PRIV_KEY") or os.environ.get("PUB_PRIV_KEY"):
    print("Using signing key from environment.")
elif _MANUAL_PRIV_KEY.strip() and not _MANUAL_PRIV_KEY.startswith("<"):
    os.environ["EOA_PRIV_KEY"] = _MANUAL_PRIV_KEY.strip()   # preferred variable name
else:
    # Leave the env unset so the signer cell reports a missing key instead of
    # trying to parse the placeholder text as a private key.
    print("No signing key configured: export EOA_PRIV_KEY (or PUB_PRIV_KEY) or fill in _MANUAL_PRIV_KEY.")
# or: os.environ["PUB_PRIV_KEY"] = "0x<your_private_key_hex>"


In [None]:
# Sanity check before running the EIP-712 cells: fail fast and name every global the
# signer cell reads that has not been defined yet. (`acct` is not required: the signer
# cell derives its own account from EOA_PRIV_KEY / PUB_PRIV_KEY.)
_required_globals = ["w3", "phase2", "phase3", "model_id_bytes", "prompt_text", "output_text", "norm_ipfs"]
_missing_globals = [name for name in _required_globals if name not in globals()]
if _missing_globals:
    raise RuntimeError("Run Phase-2 cells first; missing: " + ", ".join(_missing_globals))
print("leaf count:", len(leaves) if "leaves" in globals() else "no leaves")
print("batch root:", batch_root_hex if "batch_root_hex" in globals() else "missing")


In [None]:
# === PHASE 3 — EIP-712 receipt signer (publisher = your EOA) ===
import os, json, time
from eth_account import Account
from eth_account.messages import encode_structured_data
from web3 import Web3

# --- helpers ---
def _hex32(x: str | bytes | None) -> str:
    """Return 0x-prefixed 32-byte hex string or '0x' * 66 zero if None/empty."""
    if x is None:
        return "0x" + "00"*32
    if isinstance(x, (bytes, bytearray)):
        h = x.hex()
        if not h.startswith("0x"):
            h = "0x" + h
        return h if len(h) == 66 else ("0x" + h[2:].rjust(64, "0"))
    if isinstance(x, str):
        h = x if x.startswith("0x") else "0x"+x
        if h == "0x":
            return "0x" + "00"*32
        # strip possible double 0x (seen before)
        if h.startswith("0x0x"):
            h = "0x" + h[4:]
        return h if len(h) == 66 else ("0x" + h[2:].rjust(64, "0"))
    raise TypeError("unsupported type for _hex32")

def _to_bytes32(hx: str) -> bytes:
    """Normalize ``hx`` with ``_hex32`` and return the decoded raw bytes."""
    normalized = _hex32(hx)          # always starts with a single "0x"
    return bytes.fromhex(normalized.removeprefix("0x"))

def _keccak_text(s: str) -> bytes:
    """Return the keccak-256 digest of ``s`` as produced by ``Web3.keccak(text=...)``."""
    digest = Web3.keccak(text=s)
    return digest

# --- publisher key (use the same EOA that owns the model, or an allowed publisher) ---
PRIVATE_KEY = (os.environ.get("EOA_PRIV_KEY") or os.environ.get("PUB_PRIV_KEY") or "").strip()
if not PRIVATE_KEY:
    raise RuntimeError("Set EOA_PRIV_KEY (or PUB_PRIV_KEY) in env")
if PRIVATE_KEY[:2].lower() == "0x":
    PRIVATE_KEY = PRIVATE_KEY[2:]
# Fail early and clearly on placeholders such as "<insert_pk>" rather than deep inside
# Account.from_key. The key itself is deliberately never echoed.
if len(PRIVATE_KEY) != 64 or any(c not in "0123456789abcdefABCDEF" for c in PRIVATE_KEY):
    raise ValueError("EOA_PRIV_KEY / PUB_PRIV_KEY must be 64 hex digits (optionally 0x-prefixed)")
PRIVATE_KEY = "0x" + PRIVATE_KEY
acct = Account.from_key(PRIVATE_KEY)
publisher = acct.address

# --- chain + contract addrs ---
chain_id       = w3.eth.chain_id
verifying_addr = phase3.address          # IMPORTANT: OutputAuthentication, not Phase-2
registry_addr  = phase2.address          # FYI: used for provenance display only

# --- gather Phase-2 bindings we already discovered earlier ---
MODEL_ID_HEX   = _hex32(model_id_bytes)                                     # bytes32
WEIGHTS_HEX    = _hex32(globals().get("WEIGHTS_HASH") or globals().get("weights_hash"))
BATCH_ROOT_HEX = _hex32(globals().get("batch_root_hex"))
# Normalize the batch CID (avoids ipfs://ipfs://..., accepts raw CIDs). Fall back to `cid`,
# which the storeContentHash cell sets together with `batch_root_hex`.
raw_cid = globals().get("batch_cid") or globals().get("cid") or ""
BATCH_CID_STR = norm_ipfs(raw_cid)  # requires norm_ipfs(...) helper from the shared helpers cell

# --- compute a leaf for THIS output (unsalted; fine for demo and signature path) ---
# If you want reveal/non-reveal salts to match your Phase-2 batch exactly, replace these three lines
# with the salted hashes you used when you built/committed the batch.
input_hash  = _keccak_text(prompt_text)           # bytes32
output_hash = _keccak_text(output_text)           # bytes32
xai_hash    = b"\x00"*32                          # or Web3.keccak(text=json.dumps(xai_obj, separators=(',',':')))
leaf_bytes  = Web3.solidity_keccak(
    ["bytes32","bytes32","bytes32"],
    [input_hash, output_hash, xai_hash]
)
LEAF_HEX = _hex32(leaf_bytes)

# --- tx hash for the batch (optional); zero if unknown or unset here ---
_rcpt_batch = globals().get("rcpt_batch")
TX_HASH_HEX = _hex32(_rcpt_batch.transactionHash if _rcpt_batch is not None else None)

# --- build the typed data (must match Phase-3's RECEIPT_TYPEHASH) ---
typed_data = {
  "domain": {
    "name": "LLMOutputReceipt",
    "version": "1",
    "chainId": chain_id,
    "verifyingContract": verifying_addr,   # MUST be phase3.address
  },
  "primaryType": "Receipt",
  "types": {
    "EIP712Domain": [
      {"name":"name","type":"string"},
      {"name":"version","type":"string"},
      {"name":"chainId","type":"uint256"},
      {"name":"verifyingContract","type":"address"}
    ],
    "Receipt": [
      {"name":"modelId","type":"bytes32"},
      {"name":"weightsHash","type":"bytes32"},
      {"name":"batchRoot","type":"bytes32"},
      {"name":"batchCID","type":"string"},
      {"name":"txHash","type":"bytes32"},
      {"name":"index","type":"uint256"},
      {"name":"leaf","type":"bytes32"},
      {"name":"publisher","type":"address"},
      {"name":"timestamp","type":"uint64"}
    ]
  },
  "message": {
    "modelId":    _to_bytes32(MODEL_ID_HEX),
    "weightsHash":_to_bytes32(WEIGHTS_HEX),
    "batchRoot":  _to_bytes32(BATCH_ROOT_HEX),
    "batchCID":   BATCH_CID_STR or "",            # string hashed per EIP-712 in the contract
    "txHash":     _to_bytes32(TX_HASH_HEX),
    "index":      0,                              # set the actual leaf index if you have it
    "leaf":       _to_bytes32(LEAF_HEX),
    "publisher":  publisher,
    "timestamp":  int(time.time()),
  }
}

# Sign with encode_typed_data, the same encoder the receipt-verification cell uses, so signer
# and verifier hash identical EIP-712 bytes (encode_structured_data is deprecated in eth-account).
from eth_account.messages import encode_typed_data

msg = encode_typed_data(full_message=typed_data)
sig = Account.sign_message(msg, private_key=PRIVATE_KEY)

receipt = {
    "schema": "llm.receipt.v1",
    "mode": "non-reveal",             # or "reveal"
    "chainId": chain_id,
    "registry": registry_addr,        # for human readers; NOT in EIP-712 domain
    "modelId": MODEL_ID_HEX,
    "weightsHash": WEIGHTS_HEX,
    "batchRoot": BATCH_ROOT_HEX,
    "batchCID": BATCH_CID_STR or "",
    "txHash": TX_HASH_HEX,
    "index": 0,
    "leaf": LEAF_HEX,
    "publisher": publisher,
    "timestamp": typed_data["message"]["timestamp"],
    "eip712": {
        "domain": typed_data["domain"],
        "types":  typed_data["types"],
        "message": {
            # serialize back to hex strings for the JSON blob
            "modelId": MODEL_ID_HEX,
            "weightsHash": WEIGHTS_HEX,
            "batchRoot": BATCH_ROOT_HEX,
            "batchCID": typed_data["message"]["batchCID"],
            "txHash": TX_HASH_HEX,
            "index": typed_data["message"]["index"],
            "leaf": LEAF_HEX,
            "publisher": publisher,
            "timestamp": typed_data["message"]["timestamp"],
        },
        "signature": sig.signature.hex(),
    }
}

print("Signed receipt for index", receipt["index"])
print("publisher:", publisher)


In [None]:
# === PHASE 3 — verify EIP-712 receipt (off-chain) + optional on-chain cross-check ===
from eth_account.messages import encode_typed_data  # use this (new) instead of encode_structured_data
from eth_account import Account
from web3 import Web3

def _to_b32(hx: str) -> bytes:
    """Convert a hex string to exactly 32 raw bytes.

    Accepts surrounding whitespace and one or more ``0x``/``0X`` prefixes.
    ``""`` and ``"0x"`` both map to the zero word. Shorter values are
    left-padded with zeros, matching the ``_hex32`` normalization used when the
    receipt was produced. Without this padding, a short value would reach the
    EIP-712 encoder unpadded and be padded on the wrong side.

    Raises:
        TypeError: ``hx`` is not a str.
        ValueError: more than 32 bytes of hex, or non-hex characters.
    """
    if not isinstance(hx, str):
        raise TypeError("expected hex string")
    digits = hx.strip()
    while digits[:2].lower() == "0x":
        digits = digits[2:]
    if len(digits) > 64:
        raise ValueError(f"hex value longer than 32 bytes: {hx!r}")
    return bytes.fromhex(digits.rjust(64, "0"))

def verify_receipt(receipt: dict):
    """
    Verify a signed ``llm.receipt.v1`` receipt off-chain, then cross-check on-chain.

    The signature is checked over ``receipt["eip712"]`` (domain, types, message).
    It is accepted only if the recovered signer equals the publisher named inside
    the signed message, and that publisher equals the top-level ``publisher``.
    Comparing only against the unsigned top-level field would let a receipt claim
    a publisher that its signed payload never named.

    Returns:
        dict with keys
          ok_signature    -- signer == signed publisher == top-level publisher
          signer          -- address recovered from the signature
          ok_fields_match -- top-level mirror fields equal the signed message fields
          ok_onchain / signer_onchain -- result of phase3.verifyReceipt over the
              signed fields, or ok_onchain=None plus onchain_error if the call fails.
    """
    eip = receipt["eip712"]
    signed_msg = eip["message"]
    b32_fields = ("modelId", "weightsHash", "batchRoot", "txHash", "leaf")

    # Rebuild typed data from the JSON blob
    td = {
        "domain": eip["domain"],
        "primaryType": "Receipt",
        "types": eip["types"],
        "message": dict(signed_msg),  # shallow copy; bytes32 fields are converted below
    }
    # Convert hex strings to bytes for all bytes32 fields (required by the encoder).
    # batchCID (str), index (int), publisher ("0x..." str) and timestamp (int) are used as-is.
    for k in b32_fields:
        td["message"][k] = _to_b32(td["message"][k])

    # Encode per EIP-712 and recover signer
    signable = encode_typed_data(full_message=td)
    sig_hex = eip["signature"]
    sig_bytes = bytes.fromhex(sig_hex[2:] if sig_hex.startswith("0x") else sig_hex)
    signer = Account.recover_message(signable, signature=sig_bytes)
    ok_sig = signer.lower() == signed_msg["publisher"].lower() == receipt["publisher"].lower()

    def _fields_match() -> bool:
        # Top-level fields are an unsigned, human-readable mirror; flag any drift.
        try:
            return all(_to_b32(receipt[k]) == td["message"][k] for k in b32_fields) and all(
                receipt[k] == signed_msg[k] for k in ("batchCID", "index", "timestamp")
            )
        except (KeyError, TypeError, ValueError):
            return False

    result = {"ok_signature": ok_sig, "signer": signer, "ok_fields_match": _fields_match()}

    # ---- Optional: cross-check with the Phase-3 contract's verifyReceipt ----
    # Uses the SIGNED message fields, i.e. exactly what the signature covers.
    try:
        m = td["message"]
        r_tuple = (
            m["modelId"],
            m["weightsHash"],
            m["batchRoot"],
            m["batchCID"],
            m["txHash"],
            int(m["index"]),
            m["leaf"],
            m["publisher"],
            int(m["timestamp"]),
        )
        ok_onchain, signer_onchain = phase3.functions.verifyReceipt(r_tuple, sig_bytes).call()
        result.update({"ok_onchain": ok_onchain, "signer_onchain": signer_onchain})
    except Exception as e:
        result.update({"ok_onchain": None, "onchain_error": str(e)})

    return result

verify_receipt(receipt)
