In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import hashlib
import random
import copy
from torch.utils.data import DataLoader, TensorDataset

In [8]:
# Hash function for tensors that's tolerant to small numerical differences
def hash_tensor(tensor, precision=5):
    """Return an MD5 hex digest of ``tensor`` after rounding to ``precision`` decimal places.

    Rounding makes the digest insensitive to noise below ``10**-precision``, except for
    values that straddle a rounding boundary, which can still hash differently.

    The digest also covers the tensor's shape and dtype, so tensors that share the same
    flattened bytes but differ in shape (e.g. 2x3 vs 3x2) do not collide. Negative zero is
    normalised to positive zero so tiny values of either sign hash identically.

    Args:
        tensor: torch tensor to hash; it is detached first, so it may require grad.
        precision: number of decimal places kept before hashing.

    Returns:
        32-character hexadecimal MD5 digest string.
    """
    scale = 10 ** precision
    rounded = torch.round(tensor.detach() * scale) / scale
    # round() of a tiny negative value yields -0.0, whose sign bit makes its bytes differ
    # from +0.0; in IEEE arithmetic -0.0 + 0.0 == +0.0, which removes that difference.
    rounded = rounded + 0.0
    hasher = hashlib.md5()
    hasher.update(f"{tuple(rounded.shape)}|{rounded.dtype}|".encode())
    hasher.update(rounded.cpu().numpy().tobytes())
    return hasher.hexdigest()

In [9]:
# Wrapper for proof of matrix multiplication
class MatMulProofWrapper(nn.Module):
    """Wrap ``model`` and log hashed inputs/outputs of every ``nn.Linear`` call.

    Each forward pass is, with probability ``sample_prob`` (drawn via ``random.random()``),
    checkpointed: its input and a copy of every Linear layer's weight/bias are stored so
    that ``verify_operations`` can replay the pass and compare hashes.

    Attributes:
        log: one dict per Linear call with 'id', 'module' (``str(module)``), 'name'
            (qualified name from ``model.named_modules()``), 'input_hash', 'output_hash'
            and 'sample_id' (``None`` when the pass was not checkpointed).
        input_samples: sample_id -> detached copy of that pass's input.
        model_states: sample_id -> {``str(module)``: {'weight', 'bias'}}. Identically
            configured layers print the same ``str``, so only the last of them is kept here.
        named_model_states: sample_id -> {qualified name: {'weight', 'bias'}}, with one
            entry per Linear layer.
    """

    def __init__(self, model, sample_prob=0.2):
        """Hook every ``nn.Linear`` inside ``model``.

        Args:
            model: the module to wrap.
            sample_prob: probability that a forward pass is checkpointed for verification.
        """
        super().__init__()
        self.model = model
        self.sample_prob = sample_prob
        self.log = []  # Store operation logs
        self.input_samples = {}  # Store inputs for verification
        self.model_states = {}  # str(module)-keyed weight snapshots per sample
        self.named_model_states = {}  # qualified-name-keyed weight snapshots per sample
        self.current_input = None
        self.current_sample_id = None
        self.sample_counter = 0

        # Qualified names are unique; str(m) is not (e.g. two Linear(20, 20) layers), so
        # keying by str(m) alone would hook and snapshot only one of them.
        self.named_linears = {
            name: m for name, m in model.named_modules() if isinstance(m, nn.Linear)
        }
        # str(m)-keyed view (last layer wins on collisions), kept for consumers that look
        # layers up by str(module).
        self.module_dict = {str(m): m for m in self.named_linears.values()}

        # Register a logging hook on every Linear layer, not one per distinct repr.
        self._hook_handles = [
            module.register_forward_hook(self._make_hook(name, module))
            for name, module in self.named_linears.items()
        ]

    def _make_hook(self, name, module):
        """Build a forward hook that logs calls of ``module`` under qualified ``name``."""
        def hook(m, inp, out):
            # copy.deepcopy(model) copies this hook function onto the clone; ignore calls
            # from such copies so replays on them do not append to this wrapper's log.
            if m is not module:
                return None
            self._log_matmul(m, inp[0], out, name=name)
            return None
        return hook

    def _log_matmul(self, module, inputs, outputs, name=None):
        """Append a record with the hashes of one Linear call's input and output."""
        record = {
            'id': len(self.log),
            'module': str(module),
            'name': name,
            'input_hash': hash_tensor(inputs),
            'output_hash': hash_tensor(outputs),
            'sample_id': getattr(self, 'current_sample_id', None),
        }
        self.log.append(record)

    @staticmethod
    def _snapshot_linear(module):
        """Return detached copies of a Linear layer's weight and (optional) bias."""
        return {
            'weight': module.weight.detach().clone(),
            'bias': module.bias.detach().clone() if module.bias is not None else None,
        }

    def forward(self, x):
        """Run the wrapped model on ``x``, logging every Linear call.

        With probability ``sample_prob`` the pass is checkpointed. A new sample id is
        assigned, and the input plus all Linear weights are copied *before* the model
        runs, so the snapshot holds exactly the parameters this pass uses.
        """
        self.current_input = x

        if random.random() < self.sample_prob:
            sample_id = self.sample_counter
            self.sample_counter += 1

            self.input_samples[sample_id] = x.detach().clone()
            named_state = {
                name: self._snapshot_linear(m) for name, m in self.named_linears.items()
            }
            self.named_model_states[sample_id] = named_state
            # Same snapshot tensors re-keyed by str(module); on collisions the last layer
            # wins, matching self.module_dict.
            self.model_states[sample_id] = {
                str(m): named_state[name] for name, m in self.named_linears.items()
            }
            self.current_sample_id = sample_id
        else:
            self.current_sample_id = None

        try:
            return self.model(x)
        finally:
            # Calls made on the inner model outside this wrapper must not inherit the id.
            self.current_sample_id = None

    def get_verification_data(self):
        """Return the data ``verify_operations`` needs.

        Containers are shallow-copied so that later forward passes do not change the
        returned view. Tensors are shared, not copied.
        """
        return {
            'log': list(self.log),
            'inputs': dict(self.input_samples),
            'model_states': dict(self.model_states),
            'named_model_states': dict(self.named_model_states),
        }

    def remove_hooks(self):
        """Detach all logging hooks from the wrapped model's Linear layers."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

In [10]:
# Verification function
def _resolve_linear(op, linears_by_name, linears_by_repr):
    """Return the replay-model Linear layer an op was logged from, or None if absent.

    Prefers the qualified module name (``op['name']``) when the log recorded one, because
    ``str(module)`` is identical for identically configured layers. Otherwise falls back to
    matching ``op['module']`` against ``str(module)``.
    """
    name = op.get('name')
    if name is not None:
        return linears_by_name.get(name)
    return linears_by_repr.get(op['module'])


def _load_linear_state(module, state):
    """Copy a saved {'weight', 'bias'} snapshot into ``module`` in place, without autograd."""
    with torch.no_grad():
        module.weight.copy_(state['weight'])
        if state['bias'] is not None and module.bias is not None:
            module.bias.copy_(state['bias'])


def _make_replay_hook(expected, sink):
    """Build a hook that checks the k-th replayed call of a module against its k-th logged op.

    ``expected`` is a list of (slot index in ``sink``, op record) in call order. Each call
    writes its verdict into ``sink[slot]``. Extra calls with no logged counterpart are ignored.
    """
    pending = iter(expected)

    def hook_fn(m, inp, out):
        slot, op_data = next(pending, (None, None))
        if op_data is None:
            return
        input_match = hash_tensor(inp[0]) == op_data['input_hash']
        output_match = hash_tensor(out) == op_data['output_hash']
        sink[slot] = {
            'id': op_data['id'],
            'module': op_data['module'],
            'verified': input_match and output_match,
            'input_match': input_match,
            'output_match': output_match,
        }

    return hook_fn


def verify_operations(original_model, verification_data, num_samples=3):
    """Replay a random subset of checkpointed forward passes and check the logged hashes.

    For each selected sample, the saved ``nn.Linear`` weights are loaded into a deep copy of
    ``original_model`` and the saved input is replayed. Each Linear call's input/output hash
    is then compared with the hash logged for that call, matched by call order per module.

    Args:
        original_model: the wrapped model. It is deep-copied and never modified.
        verification_data: dict with keys 'log', 'inputs', 'model_states' and, optionally,
            'named_model_states' (sample_id -> {qualified name: {'weight', 'bias'}}).
        num_samples: maximum number of checkpointed samples to replay.

    Returns:
        ``(summary, results)``. ``summary`` is a human-readable string. ``results`` is a
        list of per-op dicts with 'id', 'module', 'verified', plus either
        'input_match'/'output_match' or 'error'. When nothing is verifiable, ``results`` is
        an empty list.

    Limitations: only Linear parameters are restored, and the copy runs in the model's
    current train/eval mode. Other stateful layers may therefore cause mismatches.
    """
    log = verification_data['log']
    inputs = verification_data['inputs']
    model_states = verification_data['model_states']
    # Name-keyed snapshots; present only if the producer recorded them.
    named_model_states = verification_data.get('named_model_states', {})

    # Group checkpointed ops by sample id, keeping log order (= call order within a pass).
    ops_by_sample = {}
    for op in log:
        if op['sample_id'] in model_states:
            ops_by_sample.setdefault(op['sample_id'], []).append(op)

    if not ops_by_sample:
        return "No operations with saved states found for verification.", []

    sample_size = min(num_samples, len(ops_by_sample))
    sample_ids = random.sample(list(ops_by_sample.keys()), sample_size)

    # Replay on a copy so original_model's parameters are never touched.
    verification_model = copy.deepcopy(original_model)
    linears_by_name = {
        name: m for name, m in verification_model.named_modules() if isinstance(m, nn.Linear)
    }
    linears_by_repr = {str(m): m for m in linears_by_name.values()}

    results = []
    for sample_id in sample_ids:
        # Load the exact Linear parameters that were live when this pass was logged.
        if sample_id in named_model_states:
            saved_state, targets = named_model_states[sample_id], linears_by_name
        else:
            saved_state, targets = model_states[sample_id], linears_by_repr
        for key, state in saved_state.items():
            if key in targets:
                _load_linear_state(targets[key], state)

        # One result slot per logged op, in log order. Slots are overwritten when replayed.
        sample_results = []
        expected_by_module = {}
        for op in ops_by_sample[sample_id]:
            module = _resolve_linear(op, linears_by_name, linears_by_repr)
            if module is None:
                sample_results.append({
                    'id': op['id'],
                    'module': op['module'],
                    'verified': False,
                    'error': 'Module not found',
                })
                continue
            expected_by_module.setdefault(module, []).append((len(sample_results), op))
            sample_results.append({
                'id': op['id'],
                'module': op['module'],
                'verified': False,
                'error': 'Operation was not executed during replay',
            })

        hooks = [
            module.register_forward_hook(_make_replay_hook(expected, sample_results))
            for module, expected in expected_by_module.items()
        ]
        try:
            with torch.no_grad():
                verification_model(inputs[sample_id])
        finally:
            for hook in hooks:
                hook.remove()

        results.extend(sample_results)

    success_count = sum(1 for r in results if r.get('verified', False))
    return f"Verified {success_count}/{len(results)} operations successfully.", results

In [11]:
# Simple model for demonstration
class SimpleModel(nn.Module):
    """Two-layer perceptron: Linear(10 -> 20), ReLU, Linear(20 -> 2)."""

    def __init__(self):
        super().__init__()
        # Creation order (fc1 then fc2) fixes the RNG draws used for initialisation, and
        # the attribute names fix the parameter names ('fc1.weight', ...).
        self.fc1 = nn.Linear(in_features=10, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=2)

    def forward(self, x):
        """Map inputs whose last dimension is 10 to 2 raw (unnormalised) outputs."""
        hidden = self.fc1(x)
        activated = torch.relu(hidden)
        return self.fc2(activated)

In [12]:
# Main function to demonstrate the entire process
def main(seed=42, num_train_batches=6):
    """Train ``SimpleModel`` through ``MatMulProofWrapper``, then spot-check the log.

    The logged matrix multiplications are checked with ``verify_operations`` and a
    per-operation report is printed.

    Args:
        seed: seed applied to torch, Python ``random`` and NumPy (default 42).
        num_train_batches: number of mini-batches of 10 samples to train on (default 6).
    """
    # Seed every RNG in play: torch (data, init, shuffling), random (checkpoint sampling).
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Synthetic binary-classification data: 100 samples with 10 features.
    X = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

    # Create model and wrapper
    original_model = SimpleModel()
    wrapped_model = MatMulProofWrapper(original_model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(wrapped_model.parameters(), lr=0.01)

    print("Training model...")
    # zip with range stops after num_train_batches without fetching an extra batch.
    for _, (data, target) in zip(range(num_train_batches), dataloader):
        optimizer.zero_grad()
        output = wrapped_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    print("Training complete!")

    verification_data = wrapped_model.get_verification_data()
    print(f"Collected {len(verification_data['log'])} matrix multiplication operations")
    print(f"Stored {len(verification_data['inputs'])} checkpoints for verification")

    print("\nVerifying operations...")
    outcome = verify_operations(original_model, verification_data)
    # verify_operations may return a bare message string (rather than a
    # (summary, results) tuple) when no forward pass was checkpointed.
    if isinstance(outcome, str):
        summary, results = outcome, []
    else:
        summary, results = outcome
    print(summary)

    # Print detailed results
    for result in results:
        if not isinstance(result, dict):
            continue
        verified = result.get('verified', False)
        status = "✓" if verified else "✗"
        print(f"Operation {result.get('id')} ({result.get('module', 'unknown')}): {status}")
        if verified:
            continue
        if 'error' in result:
            print(f"  Error: {result['error']}")
            continue
        if not result.get('input_match', True):
            print("  - Input hash mismatch")
        if not result.get('output_match', True):
            print("  - Output hash mismatch")

    print("\nVerification complete!")

if __name__ == "__main__":
    main()

Training model...
Training complete!
Collected 12 matrix multiplication operations
Stored 1 checkpoints for verification

Verifying operations...
Verified 2/2 operations successfully.
Operation 2 (Linear(in_features=10, out_features=20, bias=True)): ✓
Operation 3 (Linear(in_features=20, out_features=2, bias=True)): ✓

Verification complete!
