In [105]:
from transformers import GPT2Model, GPT2Tokenizer, GPT2LMHeadModel
# GPT2LMHeadModel / GPT2Tokenizer only load GPT-2-architecture checkpoints
# (e.g. "gpt2", "gpt2-medium", "distilgpt2"); other GPT-family models need their
# own model/tokenizer classes (or the transformers Auto* classes).
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
# Proofs are re-computed by a verifier and must match, so dropout has to stay off.
# from_pretrained already returns the model in eval mode; this makes the
# dependency explicit (and is not the cell's last line, so nothing is displayed).
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

In [106]:
import data
import torch
import typing

def topk_gradient(model: torch.nn.Module, topk_percent: float) -> typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor]]:
    """Select the largest-magnitude gradient entries of every parameter.

    Args:
        model: Module whose ``.grad`` tensors have been populated (e.g. by
            ``backward()``). Parameters whose ``.grad`` is ``None`` are skipped.
        topk_percent: Fraction (not percentage) of each parameter's elements to
            keep, e.g. ``0.01`` keeps 1%. The count is truncated with ``int`` and
            clamped to ``[0, numel]``.

    Returns:
        Mapping from parameter name to ``(indices, values)``. ``indices`` are int32
        positions into the flattened gradient, ordered by descending magnitude.
        ``values`` are the corresponding float32 gradient entries with their
        original sign.
    """
    gradient_data = {}
    for name, parameter in model.named_parameters():
        if parameter.grad is None:
            continue
        flat_grad = parameter.grad.detach().flatten()
        total_elements = flat_grad.numel()
        topk_elements = min(max(int(total_elements * topk_percent), 0), total_elements)
        # Rank by magnitude, but report the signed entries: proofs are later
        # scatter-added back onto gradients, where the sign matters.
        _, indices = torch.topk(flat_grad.abs(), topk_elements)
        values = flat_grad[indices]
        gradient_data[name] = (indices.to(torch.int32), values.to(torch.float32))
    return gradient_data

def create_proof( 
        model: torch.nn.Module,
        tokenizer: 'tokenizer',
        pages: typing.List[int], 
        batch_size: int, 
        sequence_length: int,
        device: str = 'cpu',
        topk_percent: float = 0.01,
        max_batches: typing.Optional[int] = None,
    ) -> typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor]]:
    """Compute a top-k gradient proof of the model's loss on the given pages.

    Loads the pages through ``data.SubsetFalconLoader``, runs a forward/backward
    pass on every batch (inputs double as labels), averages the loss over the
    batches processed, and returns ``topk_gradient(model, topk_percent)``.

    Side effects: moves ``model`` to ``device`` and overwrites its ``.grad``
    tensors (they are cleared with ``model.zero_grad()`` first).

    Args:
        model: Model called as ``model(inputs, labels=inputs)``; the returned
            object must expose a ``.loss`` tensor.
        tokenizer: Tokenizer forwarded to ``data.SubsetFalconLoader``.
        pages: Rows forwarded to the loader as ``rows``.
        batch_size: Batch size forwarded to the loader.
        sequence_length: Sequence length forwarded to the loader.
        device: Device the model and every batch are moved to.
        topk_percent: Fraction of each parameter's gradient entries to keep.
        max_batches: If given, only the first ``max_batches`` batches are used
            (speed cap). ``None`` processes every batch.

    Returns:
        The ``{name: (indices, values)}`` mapping produced by ``topk_gradient``.
    """
    model.to(device)
    batches = list(
        data.SubsetFalconLoader(
            tokenizer=tokenizer,
            batch_size=batch_size,
            sequence_length=sequence_length,
            rows=pages,
        )
    )
    if max_batches is not None:
        batches = batches[:max_batches]
    model.zero_grad()
    num_batches = len(batches)
    for batch in batches:
        inputs = batch.to(device)
        outputs = model(inputs, labels=inputs)
        # Scale each batch's loss so the accumulated gradient is the mean over all
        # processed batches; done out-of-place to avoid mutating the model output.
        (outputs.loss / num_batches).backward()
    return topk_gradient(model, topk_percent)
    
def check_equality(
        proof_A: typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor]],
        proof_B: typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor]],
    ) -> bool:
    """Return True when two proofs cover the same parameters with matching entries.

    Both proofs must contain exactly the same parameter names. For each name,
    indices must be identical (same order, compared as int32). Values are compared
    as float32 with ``torch.allclose`` default tolerances, so small floating-point
    differences between recomputations are accepted. Shape mismatches return
    False instead of raising.
    """
    # A parameter present in only one of the proofs means they do not match,
    # regardless of argument order.
    if proof_A.keys() != proof_B.keys():
        return False
    for key, (indices_A, values_A) in proof_A.items():
        indices_B, values_B = proof_B[key]
        # Normalize dtypes so proofs produced/serialized differently still compare.
        indices_A = indices_A.to(torch.int32)
        values_A = values_A.to(torch.float32)
        indices_B = indices_B.to(torch.int32)
        values_B = values_B.to(torch.float32)
        if indices_A.shape != indices_B.shape or values_A.shape != values_B.shape:
            return False
        if not (torch.equal(indices_A, indices_B) and torch.allclose(values_A, values_B)):
            return False
    return True


In [110]:
import random

def accumulate_proofs(model, proofs):
    """
    Accumulate multiple proofs by scatter adding them onto the grad values of the model.

    Existing gradients are added to, not replaced; parameters that require grad
    but have no ``.grad`` start from zeros. Proof tensors are moved/cast to the
    gradient's device and dtype, so float32 proofs can be applied to models in
    another precision or on another device.

    Args:
    - model (torch.nn.Module): The model whose gradients will be updated in place.
    - proofs (List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]]): A list of proofs,
      where each proof maps parameter names to (indices, values); indices address
      the flattened parameter. Names missing from a proof are skipped for that proof.
    """
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            param.grad = torch.zeros_like(param.data)
        # Flat view of the gradient (a copy only if the gradient is non-contiguous;
        # it is written back below either way).
        grad_flat = param.grad.reshape(-1)
        for proof in proofs:
            if name not in proof:
                continue
            indices, values = proof[name]
            # scatter_add_ requires index=int64 and src matching self's dtype/device.
            grad_flat.scatter_add_(
                0,
                indices.to(device=grad_flat.device, dtype=torch.long),
                values.to(device=grad_flat.device, dtype=grad_flat.dtype),
            )
        param.grad = grad_flat.view_as(param)
      
# Generate `n` single-page proofs, then sum them into the model's gradients.
n = 10
batch_size = 1
sequence_length = 50
device = 'cpu'
topk_percent = 0.01

proofs = []
from tqdm import tqdm
for i in tqdm(range(n)):
    # NOTE(review): random.randint includes its upper bound; confirm that
    # SubsetFalconLoader.max_pages is itself a valid page index.
    page = random.randint(0, data.SubsetFalconLoader.max_pages)
    proofs.append(create_proof(
        model,
        tokenizer,
        pages=[page],
        batch_size=batch_size,
        sequence_length=sequence_length,
        device=device,
        topk_percent=topk_percent,
    ))

# create_proof leaves the dense gradient of its last call in `.grad`; clear it so
# the accumulated gradients contain only the sum of the sparse proofs.
model.zero_grad()
accumulate_proofs(model, proofs)


100%|██████████| 10/10 [00:21<00:00,  2.17s/it]


In [111]:
import hashlib

def compute_proof_hash(proof, pages, batch_size, sequence_length, topk_percent):
    """
    Compute a SHA-256 hash binding a gradient proof to the parameters used to generate it.

    For every parameter name in the proof (in sorted order), the digest covers the
    name and its selected top-k indices (sorted, so tie-ordering from ``torch.topk``
    does not matter). It then covers the pages, batch size, sequence length and topk
    percent. Every field is length-prefixed so adjacent fields cannot run together
    (e.g. pages [1, 23] vs [12, 3]).

    Gradient values are deliberately excluded: they are floats whose low-order bits
    may differ between recomputations (``check_equality`` compares them with a
    tolerance for the same reason), which would break honest verification.

    Args:
    - proof (Dict[str, Tuple[torch.Tensor, torch.Tensor]]): mapping from parameter
      name to (indices, values), as produced by ``create_proof``.
    - pages (List[int]): The pages used to generate the proof.
    - batch_size (int): The batch size used.
    - sequence_length (int): The sequence length used.
    - topk_percent (float): The topk percent used.

    Returns:
    - str: The hex-encoded SHA-256 digest.
    """
    hasher = hashlib.sha256()

    def _feed(field):
        # Length-prefix each field so the encoding is unambiguous.
        encoded = str(field).encode()
        hasher.update(len(encoded).to_bytes(8, 'big'))
        hasher.update(encoded)

    for name in sorted(proof):
        indices, _values = proof[name]
        _feed(name)
        _feed(','.join(str(index) for index in sorted(indices.tolist())))
    _feed(','.join(str(page) for page in pages))
    _feed(batch_size)
    _feed(sequence_length)
    _feed(topk_percent)
    return hasher.hexdigest()

# Example usage: build a proof for a fresh set of pages and seal it with its hash.
pages = [random.randint(0, data.SubsetFalconLoader.max_pages) for _ in range(2)]
# The sealed hash must come from a proof computed on exactly these pages and
# settings; otherwise a verifier recomputing from the seal can never match it.
# Note: create_proof overwrites the model's gradients (including any accumulated above).
sealed_proof = create_proof(
    model,
    tokenizer,
    pages=pages,
    batch_size=batch_size,
    sequence_length=sequence_length,
    device=device,
    topk_percent=topk_percent,
)
proof_hash = compute_proof_hash(sealed_proof, pages, batch_size, sequence_length, topk_percent)
seal = {
    'hash': proof_hash,
    'pages': pages,
    'batch_size': batch_size,
    'sequence_length': sequence_length,
    'topk_percent': topk_percent
}


In [109]:
# Verify the seal: rebuild the proof on CPU from the sealed settings, re-hash it,
# and compare against the sealed hash (the cell displays the boolean result).
local_proof = create_proof(
    model,
    tokenizer,
    pages=seal['pages'],
    batch_size=seal['batch_size'],
    sequence_length=seal['sequence_length'],
    device='cpu',
    topk_percent=seal['topk_percent'],
)
recomputed_hash = compute_proof_hash(
    local_proof,
    seal['pages'],
    seal['batch_size'],
    seal['sequence_length'],
    seal['topk_percent'],
)
hash_check = seal['hash'] == recomputed_hash
hash_check


True