# 🕵️ Model Stealing: An Extraction Attack Demo

**Core Concept**: Model extraction is a type of attack where an adversary queries a machine learning model's API to build a surrogate (replica) model. This is **IP theft through information leakage**, not database hacking.

## 🎯 The IP Leak Problem
1.  Your model makes predictions continuously
2.  Each prediction reveals information about internal decision boundaries
3.  Attackers collect these predictions to train a replica
4.  No need to access your training data or model weights
5.  Result: The attacker replicates years of R&D for the price of the API queries

## 💰 Economics
Illustrative estimates, assuming a flat price of $0.001 per query:
-   **Small models**: ~500-1,000 queries ($0.50-$1)
-   **Medium models**: ~5,000-10,000 queries ($5-$10)
-   **Large models**: ~100,000+ queries ($100-$1,000)
-   **Your development cost**: Potentially millions of dollars
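
The dollar figures above come from multiplying the query count by a per-query price. Here is a minimal sketch of that arithmetic, assuming a flat rate of $0.001 per query (the rate used in this demo's cost printouts). `estimate_extraction_cost` is a hypothetical helper name, and real API pricing varies by provider and tier.

```python
def estimate_extraction_cost(n_queries: int, price_per_query: float = 0.001) -> float:
    """Return the total API spend in dollars for a given number of queries.

    price_per_query is an assumed flat rate, not a real vendor's price.
    """
    if n_queries < 0 or price_per_query < 0:
        raise ValueError("n_queries and price_per_query must be non-negative")
    return n_queries * price_per_query


# Illustrative budgets taken from the tiers listed above
for label, n in [("small", 1_000), ("medium", 10_000), ("large", 100_000)]:
    print(f"{label:>6}: {n:>7,} queries -> ${estimate_extraction_cost(n):,.2f}")
```

The same function can be used to check how sensitive the economics are to pricing, for example by passing `price_per_query=0.01`.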

## 📋 Demo Scenario
We simulate:
1.  **Victim**: A company with a digit classifier API (trained on scikit-learn's MNIST-like 8x8 digits dataset)
2.  **Attacker**: A competitor who wants to steal the model
3.  **Attack**: Query the API and train a surrogate model
4.  **Success**: Surrogate achieves 90%+ fidelity to the victim (the threshold checked in Step 7)

## 🛠️ Step 1: Setup & Data Loading

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

# Set random seed for reproducibility
np.random.seed(42)

# Load MNIST-like digits dataset (8x8 grayscale images)
digits = load_digits()
X, y = digits.data, digits.target

print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features (8x8 pixels)")
print(f"Classes: {np.unique(y)} (digits 0-9)")

# Split into victim's training set and holdout test set
X_victim_train, X_test, y_victim_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

print(f"\nVictim training set: {X_victim_train.shape[0]} samples")
print(f"Holdout test set: {X_test.shape[0]} samples")

## 🏢 Step 2: Train the Victim Model (The Target)
This represents the proprietary model that a company has spent significant resources developing.

In [None]:
# Train victim model (Random Forest for demonstration)
print("Training victim model (this is the IP to be stolen)...")
victim_model = RandomForestClassifier(n_estimators=100, random_state=42)
victim_model.fit(X_victim_train, y_victim_train)

# Evaluate victim model
y_victim_pred = victim_model.predict(X_test)
victim_accuracy = accuracy_score(y_test, y_victim_pred)

print("\n✅ Victim model trained!")
print(f"Victim test accuracy: {victim_accuracy:.4f}")
print(f"\nThis model represents millions in R&D investment.")

## 🔌 Step 3: Simulate the API (Prediction Interface)
In reality, this would be a REST API endpoint. Here we simulate it with a function that returns prediction probabilities.

In [None]:
def query_victim_api(X_queries, return_probabilities=True):
    """
    Simulates querying the victim model's API.
    
    Args:
        X_queries: Input samples to query
        return_probabilities: If True, return probability vectors; else return class labels
    
    Returns:
        predictions: Either probability matrix or class labels
    """
    if return_probabilities:
        # Return full probability distribution (more information leakage)
        return victim_model.predict_proba(X_queries)
    else:
        # Return only class label (less information, harder to extract)
        return victim_model.predict(X_queries)

# Test the API
sample_query = X_test[:5]
sample_predictions = query_victim_api(sample_query)

print("Example API Query Results (first 5 samples):")
print("Shape:", sample_predictions.shape)
print("\nProbability distributions:")
print(sample_predictions)
print("\n⚠️ Notice: These probabilities reveal decision boundaries!")
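
The simulated API above answers any number of queries without keeping count. Since the attack's cost is driven by the number of calls, it can help to meter them. The following is a rough sketch, not part of the original demo: `BudgetedAPI` is a hypothetical wrapper, and the stand-in model is trained here only so that the block runs on its own.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier


class BudgetedAPI:
    """Hypothetical wrapper that counts queries and refuses calls past a budget."""

    def __init__(self, model, max_queries):
        self.model = model
        self.max_queries = max_queries
        self.queries_used = 0

    def predict_proba(self, X_queries):
        X_queries = np.atleast_2d(X_queries)
        # Reject the whole batch if it would exceed the remaining budget
        if self.queries_used + len(X_queries) > self.max_queries:
            raise RuntimeError("query budget exceeded")
        self.queries_used += len(X_queries)
        return self.model.predict_proba(X_queries)


# Stand-in victim model trained on the same digits dataset used in this demo
X_demo, y_demo = load_digits(return_X_y=True)
demo_model = RandomForestClassifier(n_estimators=20, random_state=0).fit(X_demo, y_demo)

api = BudgetedAPI(demo_model, max_queries=10)
probs = api.predict_proba(X_demo[:8])      # uses 8 of the 10 allowed queries
print(probs.shape, api.queries_used)

try:
    api.predict_proba(X_demo[:5])          # would bring the total to 13
except RuntimeError as err:
    print("blocked:", err)
```

Counting one query per input row is a design choice; a real service might bill per request instead.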

## 🎯 Step 4: Attacker Generates Query Dataset
The attacker needs inputs to query. They can:
1.  Use synthetic data (random, adversarial, or sampled from distribution)
2.  Collect real examples from the wild
3.  Use transfer data from similar domains

Here we simulate synthetic queries.

In [None]:
# Attacker's query budget (number of API calls they're willing to make)
QUERY_BUDGET = 2000  # Start with 2000 queries

# Strategy 1: Random synthetic queries (uniform random in feature space)
# Note: Real attackers would use smarter strategies (active learning, etc.)
X_attacker_queries = np.random.uniform(
    low=X.min(), 
    high=X.max(), 
    size=(QUERY_BUDGET, X.shape[1])
)

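# Strategy 2 (Optional): Mix in some real data if the attacker has access
# For demonstration, we overwrite the first n_real queries with real samples.
# Caveat: these samples come from X_test, which Step 7 also uses for evaluation,
# so the reported fidelity is optimistic; a stricter demo would use a separate pool.
n_real = min(500, QUERY_BUDGET // 4)
real_indices = np.random.choice(len(X_test), n_real, replace=False)
X_attacker_queries[:n_real] = X_test[real_indices]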

print(f"Attacker generated {QUERY_BUDGET} query samples")
print(f"Query dataset shape: {X_attacker_queries.shape}")
print(f"\nEstimated cost at $0.001/query: ${QUERY_BUDGET * 0.001:.2f}")
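
Uniform noise rarely resembles a digit, so many of those queries land far from the regions the victim was trained on. One alternative in the spirit of the "sampled from distribution" option is to jitter a handful of real seed images with Gaussian noise. This is only a sketch under assumptions: `perturb_seeds` is a hypothetical helper, the attacker is assumed to own 50 real images, and `noise_scale=2.0` is an arbitrary choice for pixel values in the 0 to 16 range.

```python
import numpy as np
from sklearn.datasets import load_digits


def perturb_seeds(seeds, n_queries, noise_scale=2.0, low=0.0, high=16.0, rng=None):
    """Build synthetic queries by adding Gaussian noise to randomly chosen seeds.

    noise_scale, low and high are assumptions suited to 8x8 digits (pixels in 0..16).
    """
    rng = np.random.default_rng() if rng is None else rng
    picks = rng.integers(0, len(seeds), size=n_queries)
    noise = rng.normal(0.0, noise_scale, size=(n_queries, seeds.shape[1]))
    # Clip so every query stays inside the valid pixel range
    return np.clip(seeds[picks] + noise, low, high)


X_all, _ = load_digits(return_X_y=True)
seed_pool = X_all[:50]                      # pretend the attacker owns 50 real images
queries = perturb_seeds(seed_pool, n_queries=1000, rng=np.random.default_rng(42))
print(queries.shape, queries.min(), queries.max())
```

Queries built this way stay close to plausible inputs, which tends to make the victim's answers more informative than answers to pure noise.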

## 📡 Step 5: Execute Extraction Attack (Query & Collect)
The attacker now queries the victim API and collects predictions.

In [None]:
print("🚨 ATTACK IN PROGRESS: Querying victim API...")
print(f"Sending {QUERY_BUDGET} queries...\n")

# Collect predictions from victim model
y_attacker_soft_labels = query_victim_api(X_attacker_queries, return_probabilities=True)

# The attacker now has a training dataset: (X_attacker_queries, y_attacker_soft_labels)
print("✅ Attack data collected!")
print(f"Collected {len(y_attacker_soft_labels)} prediction vectors")
print(f"Each vector contains {y_attacker_soft_labels.shape[1]} probability scores")
print("\nAttacker now has everything needed to train a surrogate model.")

## 🤖 Step 6: Train Surrogate Model (The Replica)
Using the collected queries and predictions, the attacker trains their own model.

In [None]:
print("Training surrogate model using stolen predictions...\n")

# Train on hard labels (argmax of the stolen probabilities).
# MLPClassifier.fit expects class labels, so the soft scores are discarded here.
y_attacker_hard_labels = y_attacker_soft_labels.argmax(axis=1)

# Train surrogate (can use different architecture)
surrogate_model = MLPClassifier(hidden_layer_sizes=(50,), max_iter=500, random_state=42)
surrogate_model.fit(X_attacker_queries, y_attacker_hard_labels)

print("✅ Surrogate model trained!")
print(f"Total extraction cost: ${QUERY_BUDGET * 0.001:.2f}")
print(f"Victim's development cost: ~$1,000,000+ (estimated)")
print(f"\nROI for attacker: {1000000 / (QUERY_BUDGET * 0.001):.0f}x")
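
The cell above collapses the stolen probability vectors to hard labels, which throws away the confidence information the API leaked. One way to use the soft labels with scikit-learn is to regress the probability vector directly and take the argmax at prediction time. The sketch below swaps in `MLPRegressor` for that purpose; it is not the method used in this demo, and the split sizes, the stand-in victim, and the division by 16 to scale pixels are all assumptions made so the block runs on its own.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor

X_all, y_all = load_digits(return_X_y=True)
X_tr, X_rest, y_tr, _ = train_test_split(
    X_all, y_all, test_size=0.5, random_state=0, stratify=y_all
)
# Disjoint attacker query set and evaluation set
X_query, X_eval = X_rest[:600], X_rest[600:]

# Stand-in victim; only its predict_proba output is used by the attacker
victim = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_tr, y_tr)
soft_labels = victim.predict_proba(X_query)

# Regress the 10-dimensional probability vector (inputs scaled to 0..1 for the MLP)
soft_surrogate = MLPRegressor(hidden_layer_sizes=(50,), max_iter=1000, random_state=0)
soft_surrogate.fit(X_query / 16.0, soft_labels)

# The surrogate's predicted class is the argmax of its regressed scores
surrogate_labels = soft_surrogate.predict(X_eval / 16.0).argmax(axis=1)
agreement = np.mean(surrogate_labels == victim.predict(X_eval))
print(f"Soft-label surrogate agreement with victim: {agreement:.3f}")
```

The regressed outputs are not guaranteed to be valid probabilities, so they are used here only for ranking classes.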

## 📊 Step 7: Evaluate Extraction Success
We measure how well the surrogate replicates the victim's behavior.

In [None]:
# Evaluate surrogate on test set
y_surrogate_pred = surrogate_model.predict(X_test)
y_victim_pred_test = victim_model.predict(X_test)

# Metric 1: Surrogate accuracy (how well it predicts true labels)
surrogate_accuracy = accuracy_score(y_test, y_surrogate_pred)

# Metric 2: Fidelity (how often surrogate agrees with victim)
fidelity = accuracy_score(y_victim_pred_test, y_surrogate_pred)

# Metric 3: Accuracy gap
accuracy_gap = victim_accuracy - surrogate_accuracy

print("="*60)
print("EXTRACTION ATTACK RESULTS")
print("="*60)
print(f"Victim Accuracy (on true labels):     {victim_accuracy:.4f}")
print(f"Surrogate Accuracy (on true labels):  {surrogate_accuracy:.4f}")
print(f"Accuracy Gap:                          {accuracy_gap:.4f}")
print(f"\n🎯 Fidelity (agreement rate):          {fidelity:.4f}")
print(f"\nQueries used:                          {QUERY_BUDGET}")
print(f"Cost:                                  ${QUERY_BUDGET * 0.001:.2f}")
print("="*60)

if fidelity > 0.90:
    print("\n🚨 ATTACK SUCCESSFUL: Surrogate achieves >90% fidelity!")
    print("The victim's IP has been effectively stolen.")
else:
    print("\n⚠️ Attack partially successful. More queries may improve fidelity.")
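
Fidelity as used above is simply the fraction of inputs on which the two models predict the same class. When both models expose probabilities, a finer-grained comparison is also possible. Here is a small sketch of both measures; the function names are hypothetical, and the probability version assumes both matrices list classes in the same column order.

```python
import numpy as np


def label_fidelity(victim_labels, surrogate_labels):
    """Fraction of inputs on which the two models predict the same class."""
    victim_labels = np.asarray(victim_labels)
    surrogate_labels = np.asarray(surrogate_labels)
    if victim_labels.shape != surrogate_labels.shape:
        raise ValueError("label arrays must have the same shape")
    return float(np.mean(victim_labels == surrogate_labels))


def mean_total_variation(victim_probs, surrogate_probs):
    """Average total variation distance between per-sample probability vectors.

    0 means identical distributions; 1 means the distributions do not overlap.
    """
    diff = np.abs(np.asarray(victim_probs) - np.asarray(surrogate_probs))
    return float(0.5 * diff.sum(axis=1).mean())


# Tiny hand-made example: the models disagree on 1 of 4 inputs
v_labels = [0, 1, 2, 2]
s_labels = [0, 1, 2, 1]
print(label_fidelity(v_labels, s_labels))        # 0.75

# First sample: opposite predictions; second sample: identical predictions
v_probs = [[1.0, 0.0], [0.5, 0.5]]
s_probs = [[0.0, 1.0], [0.5, 0.5]]
print(mean_total_variation(v_probs, s_probs))    # 0.5
```

Label fidelity matches `accuracy_score(victim_predictions, surrogate_predictions)`; the total variation measure additionally penalizes a surrogate that agrees on the class but with very different confidence.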

## 📈 Step 8: Visualize Attack Effectiveness

In [None]:
# Create confusion matrix comparing victim vs surrogate predictions
cm = confusion_matrix(y_victim_pred_test, y_surrogate_pred)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=range(10), yticklabels=range(10))
plt.title(f'Victim vs Surrogate Predictions\nFidelity: {fidelity:.2%}', fontsize=14)
plt.ylabel('Victim Prediction', fontsize=12)
plt.xlabel('Surrogate Prediction', fontsize=12)
plt.tight_layout()
plt.show()

print("\nDiagonal values = agreement between models")
print("Off-diagonal values = disagreements (extraction failures)")
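
The heatmap shows where the models disagree, but the same confusion matrix can also be reduced to one agreement rate per victim class. A minimal sketch follows, using made-up predictions for a 3-class problem rather than this demo's results.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical predictions for a 3-class problem (not results from this demo)
victim_pred = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
surrogate_pred = np.array([0, 1, 1, 1, 1, 2, 2, 0, 2])

cm_demo = confusion_matrix(victim_pred, surrogate_pred, labels=[0, 1, 2])

# Rows are victim classes: diagonal / row total = share of that class
# on which the surrogate gives the same answer as the victim
row_totals = cm_demo.sum(axis=1)
per_class_agreement = np.divide(
    cm_demo.diagonal(), row_totals,
    out=np.zeros(len(row_totals)), where=row_totals > 0,
)

for cls, rate in enumerate(per_class_agreement):
    print(f"victim class {cls}: surrogate agrees on {rate:.0%}")
```

Classes with low agreement point to regions of input space where the attacker's queries gave poor coverage.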

## 🔬 Step 9: Query Budget Analysis
How does extraction fidelity change with query budget?

In [None]:
# Test different query budgets
budgets = [100, 250, 500, 1000, 2000, 5000]
fidelities = []

print("Testing extraction with different query budgets...\n")

for budget in budgets:
    # Generate queries (uniform noise only; unlike Step 4, no real samples are mixed in)
    X_queries = np.random.uniform(low=X.min(), high=X.max(), size=(budget, X.shape[1]))
    
    # Query victim
    y_queries = query_victim_api(X_queries, return_probabilities=True).argmax(axis=1)
    
    # Train surrogate
    temp_surrogate = MLPClassifier(hidden_layer_sizes=(50,), max_iter=500, random_state=42)
    temp_surrogate.fit(X_queries, y_queries)
    
    # Measure fidelity
    temp_pred = temp_surrogate.predict(X_test)
    temp_fidelity = accuracy_score(y_victim_pred_test, temp_pred)
    fidelities.append(temp_fidelity)
    
    print(f"Budget: {budget:5d} queries → Fidelity: {temp_fidelity:.4f} (Cost: ${budget * 0.001:.2f})")

# Plot results
plt.figure(figsize=(10, 6))
plt.plot(budgets, fidelities, marker='o', linewidth=2, markersize=8)
plt.axhline(y=0.90, color='r', linestyle='--', label='90% Fidelity Threshold')
plt.xlabel('Query Budget', fontsize=12)
plt.ylabel('Extraction Fidelity', fontsize=12)
plt.title('Model Extraction: Fidelity vs Query Budget', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

print("\n💡 Insight: Fidelity generally rises with query budget, though noise-only queries may plateau early.")
print(f"With {budgets[0]} queries (${budgets[0] * 0.001:.2f}), fidelity is {fidelities[0]:.2%}")

## 🛡️ Step 10: Defense Mechanisms (Discussion)

### How to Protect Against Model Extraction:

1.  **Query Limiting**: Rate limiting per user/IP
2.  **Prediction Perturbation**: Add random noise to outputs
3.  **Confidence Rounding**: Return rounded probabilities
4.  **Query Analysis**: Detect suspicious patterns (e.g., grid sampling)
5.  **Watermarking**: Embed triggers in model that reveal theft
6.  **Output Restrictions**: Return only class labels, not probabilities
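
Three of the defenses listed above (prediction perturbation, confidence rounding, and output restrictions) act on the API's response and can be sketched in a few lines. `harden_output` is a hypothetical helper, and the `decimals` and `noise_scale` defaults are arbitrary illustrative values, not recommended settings.

```python
import numpy as np


def harden_output(probs, mode="round", decimals=1, noise_scale=0.05, rng=None):
    """Apply one illustrative output defense to a probability matrix.

    mode="round": round probabilities to `decimals` places (rows may no longer sum to 1)
    mode="noise": add Gaussian noise, clip at zero, renormalize each row
    mode="label": return only the argmax class label
    """
    probs = np.asarray(probs, dtype=float)
    if mode == "round":
        return np.round(probs, decimals)
    if mode == "noise":
        rng = np.random.default_rng() if rng is None else rng
        noisy = np.clip(probs + rng.normal(0.0, noise_scale, probs.shape), 0.0, None)
        noisy += 1e-12  # guards against an all-zero row after clipping
        return noisy / noisy.sum(axis=1, keepdims=True)
    if mode == "label":
        return probs.argmax(axis=1)
    raise ValueError(f"unknown mode: {mode}")


raw = np.array([[0.62, 0.31, 0.07], [0.04, 0.13, 0.83]])
print(harden_output(raw, "round"))
print(harden_output(raw, "noise", rng=np.random.default_rng(0)))
print(harden_output(raw, "label"))
```

Note that the noise mode can change which class ranks first when two probabilities are close, which is one concrete form of the utility cost these defenses carry.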

### Trade-offs:
-   Stronger defenses → Reduced utility for legitimate users
-   Weaker defenses → Higher risk of IP theft

This is an active area of research in ML security.
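
Query limiting, the first defense in the list, is usually enforced per client over a time window. Below is a minimal sliding-window sketch. `SlidingWindowLimiter` is a hypothetical class, the limits are arbitrary, and timestamps are passed in explicitly so the example is deterministic (a real service would more likely read a monotonic clock).

```python
from collections import defaultdict, deque


class SlidingWindowLimiter:
    """Hypothetical per-client limiter: at most `max_queries` per `window_seconds`."""

    def __init__(self, max_queries, window_seconds):
        self.max_queries = max_queries
        self.window_seconds = window_seconds
        self._history = defaultdict(deque)   # client id -> timestamps of accepted queries

    def allow(self, client_id, now):
        """Return True and record the query if the client is under its limit."""
        stamps = self._history[client_id]
        # Drop timestamps that have fallen out of the window
        while stamps and now - stamps[0] >= self.window_seconds:
            stamps.popleft()
        if len(stamps) >= self.max_queries:
            return False
        stamps.append(now)
        return True


limiter = SlidingWindowLimiter(max_queries=3, window_seconds=60)
decisions = [limiter.allow("client-a", now=t) for t in (0, 1, 2, 3, 61)]
print(decisions)   # [True, True, True, False, True]
```

A limiter like this raises the time cost of an attack for a single client, but it does not by itself stop an attacker who spreads queries across many accounts.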

## 📝 Summary

### What We Demonstrated:
✅ A victim model (Random Forest) was trained on proprietary data  
✅ An attacker queried the model's API 2,000 times (`QUERY_BUDGET`)  
✅ Using only predictions, the attacker trained a surrogate model  
✅ The surrogate's fidelity to the victim is reported by the Step 7 cell  
✅ Total attack cost: $2.00 at the assumed $0.001 per query, vs millions in R&D  

### Key Takeaways:
1.  **IP Leakage**: Every prediction leaks information about your model
2.  **Economic Threat**: Attackers can steal models for pennies on the dollar
3.  **Silent Attack**: No database breach needed, just normal API usage
4.  **Scale Matters**: Larger query budgets → higher fidelity extraction
5.  **Defense is Hard**: Protecting models without hurting usability is challenging

### Real-World Impact:
Published research has reported extraction or substitute-model attacks against services including:
-   Google Prediction API
-   Amazon ML
-   BigML
-   Face++ API
-   Various commercial MLaaS platforms

**Model extraction is not theoretical; it is a real threat to ML IP.**