# Robust Membership Inference Attack Tutorial

In [None]:
%matplotlib inline  

import random
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

# Ensure reproducibility: seed Python's `random`, NumPy and PyTorch with the
# same fixed value before any data splitting or model initialisation happens.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Also seed the CUDA generators when a GPU is present.
    torch.cuda.manual_seed_all(42)

# Dataset Setup (Challenger Model Dataset and Reference Model Data Split)

In [None]:
# Dataset Preparation
# Convert images to tensors, then normalise each of the three channels with a
# fixed (mean, std) pair (presumably the CIFAR-10 channel statistics — confirm).
transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both CIFAR-10 splits are downloaded to ./data if missing and share `transform`.
train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Split the CIFAR-10 training split into two random halves: `train_set` is
# used to train the target model, `shadow_set` is the pool from which the
# reference-model datasets are sampled later in the notebook.
train_size = int(0.5 * len(train_data))
shadow_size = len(train_data) - train_size
train_set, shadow_set = random_split(train_data, [train_size, shadow_size])

# Number of output classes for every model built in this notebook.
num_classes = 10

In [None]:
# Model Architecture: ResNet18
def get_resnet_model(num_classes):
    """Build an ImageNet-pretrained ResNet-18 with a fresh classification head.

    Args:
        num_classes: Number of output classes for the replaced final layer.

    Returns:
        A torchvision ResNet-18 whose ``fc`` layer is a new ``nn.Linear``
        mapping the backbone features to ``num_classes`` logits.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # Newer torchvision deprecates ``pretrained=True`` in favour of the
        # ``weights=`` enum; IMAGENET1K_V1 is the checkpoint the legacy flag
        # selected, so the loaded parameters are unchanged.
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enum: keep the legacy call.
        model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [None]:
# Print the module structure of a freshly built 10-class model.
print(get_resnet_model(10))

# Training/Loading Challenger Model

In [None]:
# Training Parameters
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 25  # Number of epochs for the target model
batch_size = 256  # Mini-batch size for the target model's DataLoaders
learning_rate = 0.1  # SGD learning rate
weight_decay = 5e-4  # SGD weight decay
momentum = 0.9  # SGD momentum
num_classes = 10
checkpoint_dir = "./checkpoints_new4"  # Directory to save/load checkpoints
# Create the checkpoint directory up front; no error if it already exists.
os.makedirs(checkpoint_dir, exist_ok=True)

# Function to save checkpoint
def save_checkpoint(model, optimizer, epoch, path):
    """Write the model weights, optimizer state and epoch index to ``path``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch value recorded alongside the two state dicts.
        path: Destination file passed to ``torch.save``.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    torch.save(
        {
            "model_state_dict": model_state,
            "optimizer_state_dict": optimizer_state,
            "epoch": epoch,
        },
        path,
    )
    print(f"Checkpoint saved at {path}")

# Function to load checkpoint
def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state from ``path`` if the file exists.

    The checkpoint is expected to hold the keys ``"model_state_dict"``,
    ``"optimizer_state_dict"`` and ``"epoch"``, where ``"epoch"`` is the
    0-based index of the last epoch that finished (this is what
    ``train_model`` passes to ``save_checkpoint`` at the end of each epoch).

    Args:
        model: Module to load the saved weights into.
        optimizer: Optimizer to load the saved state into.
        path: Checkpoint file path.

    Returns:
        ``(model, optimizer, start_epoch)`` where ``start_epoch`` is the
        0-based index of the next epoch to run: ``saved_epoch + 1`` when a
        checkpoint was loaded, ``0`` when no file exists at ``path``.
    """
    if not os.path.isfile(path):
        print(f"No checkpoint found at {path}, starting fresh.")
        return model, optimizer, 0

    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa) instead of failing during deserialisation.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Move the model and optimizer states to the appropriate device
    model = model.to(device)
    for state in optimizer.state.values():
        if isinstance(state, torch.Tensor):
            state.data = state.data.to(device)
        elif isinstance(state, dict):
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)

    # The stored epoch has already been trained; resume with the one after it
    # so a completed epoch is not repeated (and a finished run is not extended).
    start_epoch = checkpoint["epoch"] + 1
    print(f"Checkpoint loaded from {path}, starting at epoch {start_epoch + 1}")
    return model, optimizer, start_epoch

# Training Function
def train_model(model, train_loader, optimizer, criterion, epochs, test_loader=None, checkpoint_path=None):
    """Train ``model`` with optional checkpoint resume and periodic evaluation.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(images, labels)`` training batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        epochs: Upper bound (exclusive, 0-based) of the epoch loop.
        test_loader: When truthy, ``evaluate_model`` runs every 5th epoch.
        checkpoint_path: When truthy, training first calls ``load_checkpoint``
            on this path and calls ``save_checkpoint`` after every epoch.

    Returns:
        The trained model.
    """
    first_epoch = 0

    # Resume from an existing checkpoint when a path was supplied.
    if checkpoint_path:
        model, optimizer, first_epoch = load_checkpoint(model, optimizer, checkpoint_path)

    # Make sure the parameters live on the training device.
    model.to(device)

    for epoch in range(first_epoch, epochs):
        epoch_number = epoch + 1
        model.train()
        loss_sum = 0.0
        num_correct = 0
        num_seen = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch_number}/{epochs}", leave=True)

        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass, backward pass, parameter update.
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # Running statistics for this epoch.
            batch_loss = loss.item()
            loss_sum += batch_loss
            predicted = logits.max(1)[1]
            num_seen += labels.size(0)
            num_correct += predicted.eq(labels).sum().item()

            # Show the live loss / accuracy on the progress bar.
            progress.set_postfix(loss=batch_loss, accuracy=100. * num_correct / num_seen)

        mean_loss = loss_sum / len(train_loader)
        train_accuracy = 100. * num_correct / num_seen
        print(f"Epoch {epoch_number}/{epochs} - Train Loss: {mean_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%")

        # Persist progress at the end of each epoch.
        if checkpoint_path:
            save_checkpoint(model, optimizer, epoch, checkpoint_path)

        # Evaluate on the test set every 5 epochs.
        if test_loader and epoch_number % 5 == 0:
            evaluate_model(model, test_loader, epoch_number)

    return model

# Evaluation Function
def evaluate_model(model, test_loader, epoch):
    """Report accuracy and weighted precision/recall/F1 of ``model`` on a loader.

    Args:
        model: Network to evaluate (switched to eval mode here).
        test_loader: Iterable of ``(images, labels)`` batches.
        epoch: Epoch number, used only in the progress-bar and result text.
    """
    model.eval()
    num_correct = 0
    num_seen = 0
    y_true = []
    y_pred = []

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        progress = tqdm(test_loader, desc=f"Evaluating at Epoch {epoch}", leave=True)
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).max(1)[1]
            num_seen += labels.size(0)
            num_correct += predicted.eq(labels).sum().item()

            # Collect per-sample labels/predictions for the sklearn metrics.
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    accuracy = 100. * num_correct / num_seen
    precision = precision_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")
    f1 = f1_score(y_true, y_pred, average="weighted")

    print(f"Test Results at Epoch {epoch} - Accuracy: {accuracy:.2f}%, "
          f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

In [None]:
# Main Training Process for Target Model
print("Training Target Model...")
target_model = get_resnet_model(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(target_model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# The target model trains on `train_set`; `test_data` is only used for the
# periodic evaluation inside train_model.
target_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Checkpoint path for the target model
target_checkpoint_path = os.path.join(checkpoint_dir, "target_model.pth")

# train_model loads this checkpoint first if the file already exists.
target_model = train_model(target_model, target_loader, optimizer, criterion, epochs, test_loader, target_checkpoint_path)

# Robust MIA Setup

In [None]:
# Parameters
num_reference_models = 5  # Number of reference models
max_usage_ratio = 0.1  # Maximum usage of 10% of the shadow dataset
batch_size = 128  # Batch size for DataLoader
attack_mode = "offline"  # "offline" or "online"
seed = 42  # Random seed for reproducibility

# Set random seeds. The reference subsets are drawn with Python's `random`
# module, so it has to be seeded here as well; seeding torch alone would make
# the sampled subsets change every time this cell is re-run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Function to generate reference datasets with max usage constraint
def generate_limited_reference_datasets(shadow_set, num_models, max_usage_ratio, include_target=False):
    """
    Generate equally sized random subsets of ``shadow_set`` for the reference models.

    ``int(max_usage_ratio * len(shadow_set))`` samples are budgeted in total
    and split evenly across the models. Each subset is sampled independently
    (without replacement within a subset), so different subsets may overlap.

    Args:
        shadow_set: The shadow dataset from which reference datasets are sampled.
        num_models: Number of reference models. Values <= 0 yield an empty list.
        max_usage_ratio: Fraction of ``shadow_set`` budgeted across all models.
        include_target: Accepted for interface compatibility only.
            NOTE(review): this function receives no target data, so every
            subset is drawn from ``shadow_set`` whatever this flag is set to.
            A genuine "online" attack with IN models would need the target
            samples passed in and mixed into some subsets.

    Returns:
        reference_datasets: List of ``Subset`` objects, one per reference model.
    """
    if num_models <= 0:
        # Nothing to build; also avoids dividing the budget by zero.
        return []

    total_allowed_samples = int(max_usage_ratio * len(shadow_set))
    samples_per_model = total_allowed_samples // num_models  # Evenly distribute samples across models

    shadow_indices = list(range(len(shadow_set)))
    return [
        Subset(shadow_set, random.sample(shadow_indices, samples_per_model))
        for _ in range(num_models)
    ]

# Generate reference datasets
# "online" mode requests IN and OUT datasets (include_target=True); any other
# mode requests OUT datasets only (include_target=False).
reference_datasets = generate_limited_reference_datasets(
    shadow_set,
    num_reference_models,
    max_usage_ratio,
    include_target=(attack_mode == "online"),
)

In [None]:
# Training Reference Models
reference_model_epochs = 10  # Epochs per reference model

print("\nTraining Reference Models...")
reference_models = []
for i, ref_dataset in enumerate(reference_datasets):
    print(f"\nTraining Reference Model {i + 1}/{len(reference_datasets)}...")
    ref_loader = DataLoader(ref_dataset, batch_size=batch_size, shuffle=True)
    # Fresh model and optimizer per reference dataset, using the same SGD
    # hyperparameters as the target model.
    ref_model = get_resnet_model(num_classes)
    ref_optimizer = optim.SGD(ref_model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    # One checkpoint file per reference model; train_model loads it first if
    # it already exists. Passing None as test_loader skips periodic evaluation.
    ref_checkpoint_path = os.path.join(checkpoint_dir, f"reference_model_{i + 1}.pth")
    ref_model = train_model(ref_model, ref_loader, ref_optimizer, criterion, reference_model_epochs, None, ref_checkpoint_path)
    reference_models.append(ref_model)

print("Training Complete.")

In [None]:
from torch.utils.data import DataLoader, Subset
import random

# Prepare member and non-member datasets
# Members are sampled from `train_set` (the data the target model was trained
# on); non-members are sampled from the CIFAR-10 test split.
ref_datapoint_count = 1000  # Total points to sample for members/non-members
# Half of the total goes to each group.
member_indices = random.sample(range(len(train_set)), ref_datapoint_count // 2)
non_member_indices = random.sample(range(len(test_data)), ref_datapoint_count // 2)

member_dataset = Subset(train_set, member_indices)
non_member_dataset = Subset(test_data, non_member_indices)

# DataLoaders for members and non-members
member_loader = DataLoader(member_dataset, batch_size=128, shuffle=True)
non_member_loader = DataLoader(non_member_dataset, batch_size=128, shuffle=True)

## Robust Membership Inference Attack

In [None]:
import torch
import numpy as np
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F


# def compute_P_x(reference_models, x, non_member_loader):
#     """
#     Compute P(x) using OUT reference models trained on non-members.

#     Args:
#         reference_models: List of OUT models (trained on non-members).
#         x: Target data point (dictionary with 'input' and 'label').
#         non_member_loader: DataLoader providing access to non-member data.
    
#     Returns:
#         P_x: Average probability of x in the population (non-member).
#     """
#     x_softmax_scores = []

#     for model in reference_models:
#         model.eval()
#         with torch.no_grad():
#             for inputs, _ in non_member_loader:
#                 inputs = inputs.to(next(model.parameters()).device)
#                 outputs = model(inputs)
#                 softmax_scores = F.softmax(outputs, dim=1)
#                 x_scores = softmax_scores[:, x['label']].cpu().numpy()  # Extract probabilities for x's label
#                 x_softmax_scores.extend(x_scores)

#     P_x = np.mean(x_softmax_scores)
#     return P_x

# P(x): Probability of x in the population
def compute_P_x(reference_models, x, non_member_loader):
    """
    Estimate P(x) as the mean true-label softmax probability over reference models.

    Args:
        reference_models: List of OUT models (trained on non-members).
        x: Target data point (dictionary with 'input' and 'label'). 'input'
            may be a single unbatched image tensor (3-D, assumed (C, H, W))
            or an already batched tensor; a batch axis is added to 3-D input.
        non_member_loader: Unused; kept so existing callers keep working.

    Returns:
        P_x: Mean over ``reference_models`` of the softmax probability that
        each model assigns to ``x['label']`` for the first item in the batch.
    """
    label = x["label"]
    x_softmax_scores = []

    with torch.no_grad():
        for model in reference_models:
            model.eval()
            # Ensure input is on the same device as the model
            x_input = x["input"].to(next(model.parameters()).device)
            if x_input.dim() == 3:
                # A single image has no batch axis; the sibling
                # compute_P_Model_given_x adds one too. Without it, batch-norm
                # layers reject the 3-D tensor and the softmax indexing below
                # would read the wrong axis.
                x_input = x_input.unsqueeze(0)
            softmax_scores = F.softmax(model(x_input), dim=1)
            # Probability assigned to the specific class label
            x_softmax_scores.append(softmax_scores[0, label].item())

    P_x = np.mean(x_softmax_scores)
    return P_x



# P(Model | x): Likelihood of the model given x
def compute_P_Model_given_x(in_reference_models, x):
    """
    Average the true-label softmax probability of ``x`` over the IN models.

    Args:
        in_reference_models: List of IN reference models (trained with x).
        x: Target data point (dictionary with 'input' and 'label'); 'input'
            is a single unbatched sample, a batch axis is added here.

    Returns:
        P_Model_given_x: Mean, across ``in_reference_models``, of the softmax
        probability each model gives to ``x['label']``.
    """

    def _true_label_probability(model):
        # Score one model: eval mode, sample moved to the model's device.
        model.eval()
        model_device = next(model.parameters()).device
        batch = x["input"].unsqueeze(0).to(model_device)  # Add batch dimension
        probabilities = F.softmax(model(batch), dim=1)
        return probabilities[0, x["label"]].item()

    with torch.no_grad():
        per_model_scores = [_true_label_probability(model) for model in in_reference_models]

    # Mean probability across all IN reference models
    return np.mean(per_model_scores)


# Likelihood Ratio Computation
def compute_likelihood_ratio(reference_models, in_reference_models, x, non_member_loader):
    """
    Return LR(x) = P(Model | x) / P(x) for a single data point.

    Args:
        reference_models: List of OUT models, forwarded to ``compute_P_x``.
        in_reference_models: List of IN models, forwarded to
            ``compute_P_Model_given_x``.
        x: Target data point (dictionary with 'input' and 'label').
        non_member_loader: Forwarded unchanged to ``compute_P_x``.

    Returns:
        likelihood_ratio: The numerator divided by ``P(x) + 1e-10``.
    """
    # The small constant keeps the division defined when P(x) is zero.
    denominator = compute_P_x(reference_models, x, non_member_loader) + 1e-10
    numerator = compute_P_Model_given_x(in_reference_models, x)
    return numerator / denominator



# Perform RMIA
def rmi_attack(target_model, in_reference_models, out_reference_models, member_loader, non_member_loader):
    """
    Score every member and non-member sample with the likelihood ratio.

    NOTE(review): ``target_model`` is accepted but never used below, so the
    scores depend only on the reference models. A membership attack on the
    target presumably needs the target model's confidence on x — confirm the
    intended formula.

    Args:
        target_model: The target model to attack (currently unused).
        in_reference_models: List of IN reference models.
        out_reference_models: List of OUT reference models.
        member_loader: DataLoader yielding ``(inputs, labels)`` member batches.
        non_member_loader: DataLoader yielding non-member batches; it is also
            forwarded to ``compute_likelihood_ratio``.

    Returns:
        member_scores: Likelihood ratio scores for member samples.
        non_member_scores: Likelihood ratio scores for non-member samples.
    """

    def _score_loader(loader):
        # One likelihood-ratio score per sample, in loader order.
        scores = []
        for inputs, labels in loader:
            for i, sample in enumerate(inputs):
                x = {"input": sample, "label": labels[i].item()}
                scores.append(
                    compute_likelihood_ratio(out_reference_models, in_reference_models, x, non_member_loader)
                )
        return scores

    print("Computing Likelihood Ratios for Members...")
    member_scores = _score_loader(member_loader)

    print("Computing Likelihood Ratios for Non-Members...")
    non_member_scores = _score_loader(non_member_loader)

    return member_scores, non_member_scores


# Evaluate Attack
def evaluate_attack(member_scores, non_member_scores, threshold):
    """
    Threshold the scores and report classification metrics for the attack.

    Args:
        member_scores: Likelihood ratio scores for member samples.
        non_member_scores: Likelihood ratio scores for non-member samples.
        threshold: A sample is predicted "member" when its score is strictly
            greater than this value.

    Returns:
        accuracy, precision, recall, F1-score for the attack.
    """
    combined_scores = member_scores + non_member_scores

    # Ground truth: members are 1, non-members are 0 (same order as the scores).
    y_true = [1] * len(member_scores) + [0] * len(non_member_scores)
    y_pred = [int(score > threshold) for score in combined_scores]

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    print(f"Attack Results - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, "
          f"Recall: {recall:.4f}, F1 Score: {f1:.4f}")

    return accuracy, precision, recall, f1

In [None]:
%matplotlib inline  

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

def compute_roc(member_scores, non_member_scores):
    """
    Build the ROC curve and its AUC from member / non-member scores.

    Args:
        member_scores: Likelihood ratio scores for member samples (label 1).
        non_member_scores: Likelihood ratio scores for non-member samples (label 0).

    Returns:
        fpr: False Positive Rate.
        tpr: True Positive Rate.
        thresholds: Thresholds used for the ROC curve.
        roc_auc: Area Under the ROC Curve.
    """
    # Labels are laid out in the same order as the concatenated scores.
    labels = [1] * len(member_scores) + [0] * len(non_member_scores)
    combined_scores = member_scores + non_member_scores

    fpr, tpr, thresholds = roc_curve(labels, combined_scores)
    return fpr, tpr, thresholds, auc(fpr, tpr)

def plot_roc_curve(reference_models_list, member_loader, non_member_loader, target_model):
    """
    Draw one ROC curve per group of reference models on a single figure.

    Each group is split in half: the first half is passed to ``rmi_attack``
    as the IN models and the second half as the OUT models.

    Args:
        reference_models_list: List of lists, each containing a varying number of reference models.
        member_loader: DataLoader for member samples.
        non_member_loader: DataLoader for non-member samples.
        target_model: Target model to attack.

    Raises:
        ValueError: If the argument is not a list of lists, or a group is empty.
    """
    is_list_of_lists = isinstance(reference_models_list, list) and all(
        isinstance(group, list) for group in reference_models_list
    )
    if not is_list_of_lists:
        raise ValueError("reference_models_list must be a list of lists of models")

    plt.figure(figsize=(10, 7))

    for index, group in enumerate(reference_models_list):
        if not group:
            raise ValueError(f"Reference models for index {index} are empty or invalid.")

        group_size = len(group)
        half = group_size // 2
        print(f"Evaluating for {group_size} reference models...")

        member_scores, non_member_scores = rmi_attack(
            target_model,
            group[:half],  # IN reference models
            group[half:],  # OUT reference models
            member_loader,
            non_member_loader,
        )

        fpr, tpr, _thresholds, roc_auc = compute_roc(member_scores, non_member_scores)
        plt.plot(fpr, tpr, label=f"{group_size} Models (AUC = {roc_auc:.4f})")

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], "k--")
    plt.title("ROC Curves for RMIA with Varying Reference Models")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.grid()
    plt.show()

In [None]:
# Split the trained reference models: the first half is used as the "IN"
# models, the remainder as the "OUT" models.
# NOTE(review): every reference model was trained on a subset of `shadow_set`,
# which is disjoint from `train_set`, so the "IN" models have not actually
# seen the member samples being scored — confirm this split is intended.
num_models = len(reference_models)
in_reference_models = reference_models[:num_models // 2]
out_reference_models = reference_models[num_models // 2:]


# Perform Attack and Evaluate
print("Performing RMIA...")
member_scores, non_member_scores = rmi_attack(
    target_model, in_reference_models, out_reference_models, member_loader, non_member_loader
)

# Use a dynamic threshold (e.g., median of all scores)
all_scores = member_scores + non_member_scores
threshold = np.median(all_scores)

# Evaluate the attack
evaluate_attack(member_scores, non_member_scores, threshold)

In [None]:
# Notebook cell output: number of trained reference models.
len(reference_models)

In [None]:
%matplotlib inline
# Plot a single ROC curve using all trained reference models as one group.
plot_roc_curve([reference_models], member_loader, non_member_loader, target_model)