In [1]:
import random
import math
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import pytorch_lightning as pl

import json

# check if notebook is in colab
try:
    # install ezkl
    import google.colab
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

# rely on local installation of ezkl if the notebook is not in colab
except:
    pass

  warn(


In [2]:
class BaseDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
        super().__init__()
        self.contexts, self.questions, self.answers = self.get_dataset(*args, **kwargs)
        self.split = int(len(self.contexts) * split)
        self.batch_size = batch_size

    def train_dataloader(self):
        train_contexts = self.contexts[:self.split]
        train_questions = self.questions[:self.split]
        train_answers = self.answers[:self.split]
        return torch.utils.data.DataLoader(list(zip(train_contexts, train_questions, train_answers)), batch_size=self.batch_size)

    def val_dataloader(self):
        val_contexts = self.contexts[self.split:]
        val_questions = self.questions[self.split:]
        val_answers = self.answers[self.split:]
        return torch.utils.data.DataLoader(list(zip(val_contexts, val_questions, val_answers)), batch_size=self.batch_size)
  
class CustomDataModule(BaseDataModule):
    def get_dataset(self):
      context, question, answer = self.preprocess_custom_data("data.json")
      return context, question, answer
    
    def preprocess_custom_data(self, json_file):
      with open(json_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
    
        contexts = []
        questions = []
        answers = []
    
        for entry in data:
          context = entry['context']
          question = entry['question']
          answer_text = entry['answers']['text'][0]
          answer_start = entry['answers']['answer_start'][0]
        
          contexts.append(context)
          questions.append(question)
          answers.append((answer_text, answer_start))
    
      return contexts, questions, answers

In [3]:
def attention(queries, keys, values):
  d = queries.shape[-1]
  scores = torch.matmul(queries, keys.transpose(-2,-1))/math.sqrt(d)
  attention_weights = F.softmax(scores, dim=-1)
  return torch.matmul(attention_weights, values)

class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_dim, self.num_heads = embed_dim, num_heads
    assert embed_dim % num_heads == 0
    self.projection_dim = embed_dim // num_heads
    
    self.W_q = nn.Linear(embed_dim, embed_dim)
    self.W_k = nn.Linear(embed_dim, embed_dim)
    self.W_v = nn.Linear(embed_dim, embed_dim)
    self.W_o = nn.Linear(embed_dim, embed_dim)

  def transpose(self, x):
    x = x.reshape(x.shape[0], x.shape[1], self.num_heads, self.projection_dim)
    return x.permute(0, 2, 1, 3)
  
  def transpose_output(self, x):
    x = x.permute(0, 2, 1, 3)
    return x.reshape(x.shape[0], x.shape[1], self.embed_dim)
    
  def forward(self, q, k, v):
    q = self.transpose(self.W_q(q))
    k = self.transpose(self.W_k(k))
    v = self.transpose(self.W_v(v))
    output = attention(q, k, v)
    return self.W_o(self.transpose_output(output))
  
class TransformerBlock(nn.Module):
  def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
    super(TransformerBlock, self).__init__()
    self.att = MultiHeadAttention(embed_dim, num_heads)
    self.ffn = nn.Sequential(
      nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim)
    )
    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.dropout = nn.Dropout(rate)
    
  def forward(self, x):
    x = self.layernorm1(x + self.dropout(self.att(x, x, x)))
    x = self.layernorm2(x + self.dropout(self.ffn(x)))
    return x
  
class TokenAndPositionEmbedding(nn.Module):
  def __init__(self, maxlen, vocab_size, embed_dim):
    super(TokenAndPositionEmbedding, self).__init__()
    self.token_emb = nn.Embedding(vocab_size, embed_dim)
    self.pos_emb = nn.Embedding(maxlen, embed_dim)
    
  def forward(self, x):
    pos = torch.arange(0, x.size(1), dtype=torch.int32, device=x.device)
    return self.token_emb(x) + self.pos_emb(pos).view(1, x.size(1), -1)

In [4]:
import torch
import torch.nn.functional as F

def compute_question_answering_loss(model_output, answer_spans):
    # extract predicted start / end position logits
    start_logits, end_logits = model_output
    
    # unpack answer spans
    answer_text, answer_start = answer_spans

    # convert answer spans to tensor
    answer_start = torch.tensor(answer_start, dtype=torch.long).to(start_logits.device)
    
    # calculate cross-entropy loss for start / end positions
    start_loss = F.cross_entropy(start_logits, answer_start)
    
    # sum / average loss as needed
    total_loss = start_loss

    return total_loss

def evaluate_question_answering(predictions, references):
    """
    Compute Exact Match (EM) and F1-score for question answering.

    Args:
        predictions (list): List of predicted answer strings.
        references (list): List of reference answer strings.

    Returns:
        em (float): Exact Match score.
        f1 (float): F1-score.
    """
    assert len(predictions) == len(references), "Number of predictions must match number of references."

    em_count = 0
    f1_total = 0

    for pred, ref in zip(predictions, references):
        # Tokenize predicted and reference answers
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()

        common_tokens = set(pred_tokens) & set(ref_tokens)
        if not common_tokens:
            # No common tokens, EM score is 0
            em_count += 0
            f1_total += 0
        else:
            # Compute F1-score
            precision = len(common_tokens) / len(pred_tokens)
            recall = len(common_tokens) / len(ref_tokens)
            f1 = (2 * precision * recall) / (precision + recall)
            f1_total += f1

            # Exact Match (EM) is 1 if F1-score is 1, else 0
            em_count += int(f1 == 1)

    em = em_count / len(predictions)
    f1 = f1_total / len(predictions)

    return em, f1

class LittleTransformer(pl.LightningModule):
  def __init__(self, seq_len=6, max_value=10, layer_count=2, embed_dim=128, num_heads=4, ff_dim=32):
    super().__init__()
    self.max_value = max_value
    self.model = nn.Sequential(
      TokenAndPositionEmbedding(seq_len, max_value, embed_dim),
      *[TransformerBlock(embed_dim, num_heads, ff_dim) for x in range(layer_count)],
      nn.Linear(embed_dim, max_value),
      nn.LogSoftmax(dim=-1))
    
  def forward(self, context, question):
    input_embeddings = torch.cat((self.model.token_emb(context), self.model.token_emb(question)), dim=1)
    return self.model(input_embeddings)
  
  def training_step(self, batch, batch_idx):
    context, question, answer = batch
    output = self.model(context, question)
    loss = compute_question_answering_loss(output, answer)
    self.log("train_loss", loss)
    return loss
  
  def validation_step(self, val_batch, batch_idx):
    context, question, answer = val_batch
    pred = self.model(context, question).argmax(dim=2)
    evaluation_metrics = evaluate_question_answering(pred, answer)
    self.log("val_metrics", evaluation_metrics, prog_bar=True)
  
  def configure_optimizers(self):
    if self.device.type == 'cuda':
      import apex
      return apex.optimizers.FusedAdam(self.parameters(), lr=3e-4)
    else:
      return torch.optim.Adam(self.parameters(), lr=3e-4)

In [5]:
data = CustomDataModule(batch_size=64)
model = LittleTransformer(seq_len=6)
trainer = pl.Trainer(enable_progress_bar=True, max_epochs=0)
trainer.fit(model, data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 153 K 
-------------------------------------
153 K     Trainable params
0         Non-trainable params
153 K     Total params
0.613     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\sabri\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


TypeError: Sequential.forward() takes 2 positional arguments but 3 were given

## EZKL 

In [None]:

import os 

model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.compiled')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')
srs_path = os.path.join('kzg.srs')
witness_path = os.path.join('witness.json')
data_path = os.path.join('input.json')



In [None]:

import json 
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = torch.ones([1, 6], dtype=torch.long)
x = x.reshape([1, 6])

print(x)

# Flips the neural net into inference mode
model.eval()
model.to('cpu')

    # Export the model
torch.onnx.export(model,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

data_array = ((x).detach().numpy()).reshape([-1]).tolist()

data_json = dict(input_data = [data_array])

print(data_json)

    # Serialize data into file:
json.dump( data_json, open(data_path, 'w' ))


In [None]:
import ezkl 

!RUST_LOG=trace
# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path)
assert res == True

In [None]:


res = await ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
assert res == True


In [None]:
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

In [None]:
# srs path
res = ezkl.get_srs(srs_path, settings_path)

In [None]:
# now generate the witness file 
witness_path = "gan_witness.json"

res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
res = ezkl.mock(witness_path, compiled_model_path)
assert res == True

In [None]:

# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK

res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

In [None]:
# GENERATE A PROOF


proof_path = os.path.join('test.pf')

proof = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "single",
    )

print(proof)
assert os.path.isfile(proof_path)

In [None]:
# VERIFY IT
res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path,
    )

assert res == True
print("verified")

In [None]:
sol_code_path = os.path.join('Verifier.sol')
abi_path = os.path.join('Verifier.abi')

res = ezkl.create_evm_verifier(
        vk_path,
        srs_path,
        settings_path,
        sol_code_path,
        abi_path
    )

assert res == True
assert os.path.isfile(sol_code_path)

In [None]:
onchain_input_array = []

# using a loop
# avoiding printing last comma
formatted_output = "["
for i, value in enumerate(proof["instances"]):
    for j, field_element in enumerate(value):
        onchain_input_array.append(ezkl.vecu64_to_felt(field_element))
        formatted_output += str(onchain_input_array[-1])
        if j != len(value) - 1:
            formatted_output += ", "
    formatted_output += "]"

# This will be the values you use onchain
# copy them over to remix and see if they verify
# What happens when you change a value?
print("pubInputs: ", formatted_output)
print("proof: ", "0x" + proof["proof"])